Commit ffc68688 authored by: J JepsonWong

add dataloader for seq2seq, test=develop

Parent 87e87ae7
@@ -22,6 +22,7 @@ import os
import io
import sys
import numpy as np
import paddle.fluid as fluid

Py3 = sys.version_info[0] == 3
@@ -135,6 +136,53 @@ def raw_data(src_lang,
        (src_vocab, tar_vocab)


def raw_train_data(src_lang, tar_lang, vocab_prefix, train_prefix,
                   max_sequence_len=50):
    src_vocab_file = vocab_prefix + "." + src_lang
    tar_vocab_file = vocab_prefix + "." + tar_lang
    src_train_file = train_prefix + "." + src_lang
    tar_train_file = train_prefix + "." + tar_lang

    src_vocab = _build_vocab(src_vocab_file)
    tar_vocab = _build_vocab(tar_vocab_file)

    train_src, train_tar = _para_file_to_ids(src_train_file, tar_train_file,
                                             src_vocab, tar_vocab)
    train_src, train_tar = filter_len(
        train_src, train_tar, max_sequence_len=max_sequence_len)
    return (train_src, train_tar)


def raw_eval_data(src_lang, tar_lang, vocab_prefix, eval_prefix):
    src_vocab_file = vocab_prefix + "." + src_lang
    tar_vocab_file = vocab_prefix + "." + tar_lang
    src_eval_file = eval_prefix + "." + src_lang
    tar_eval_file = eval_prefix + "." + tar_lang

    src_vocab = _build_vocab(src_vocab_file)
    tar_vocab = _build_vocab(tar_vocab_file)

    eval_src, eval_tar = _para_file_to_ids(src_eval_file, tar_eval_file,
                                           src_vocab, tar_vocab)
    return (eval_src, eval_tar)


def raw_test_data(src_lang, tar_lang, vocab_prefix, test_prefix):
    src_vocab_file = vocab_prefix + "." + src_lang
    tar_vocab_file = vocab_prefix + "." + tar_lang
    src_test_file = test_prefix + "." + src_lang
    tar_test_file = test_prefix + "." + tar_lang

    src_vocab = _build_vocab(src_vocab_file)
    tar_vocab = _build_vocab(tar_vocab_file)

    test_src, test_tar = _para_file_to_ids(src_test_file, tar_test_file,
                                           src_vocab, tar_vocab)
    return (test_src, test_tar)
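
The three loaders differ only in which file prefix they read and in whether the max_sequence_len cap is applied (train data only). A hypothetical call for an en-vi setup; the data/vocab.*, data/train.* and data/tst2012.* paths are assumed for illustration and are not part of this commit:

# "data/vocab", "data/train" and "data/tst2012" are assumed prefixes,
# shown only to illustrate the file-naming convention
train_src, train_tar = raw_train_data(
    "en", "vi", vocab_prefix="data/vocab", train_prefix="data/train",
    max_sequence_len=50)    # caps sequence length at 50 tokens
eval_src, eval_tar = raw_eval_data(
    "en", "vi", vocab_prefix="data/vocab", eval_prefix="data/tst2012")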


def raw_mono_data(vocab_file, file_path):
    src_vocab = _build_vocab(vocab_file)
@@ -144,6 +192,18 @@ def raw_mono_data(vocab_file, file_path):
    return (test_src, test_tar)


def prepare_input(batch):
    src_ids, src_mask, tar_ids, tar_mask = batch
    src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
    # teacher forcing: the decoder input drops the last target token,
    # the label drops the first one
    in_tar = tar_ids[:, :-1]
    label_tar = tar_ids[:, 1:]
    in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
    label_tar = label_tar.reshape(
        (label_tar.shape[0], label_tar.shape[1], 1))
    inputs = (src_ids, in_tar, label_tar, src_mask, tar_mask,
              np.sum(tar_mask))
    return inputs
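
A quick shape check of the teacher-forcing split, using a toy numpy batch (all id values are illustrative):

import numpy as np

# toy batch: 2 sentence pairs, arbitrary token ids
src_ids = np.array([[4, 9, 2], [4, 3, 2]], dtype='int64')
src_mask = np.array([3, 3], dtype='int32')
tar_ids = np.array([[1, 5, 6, 2], [1, 7, 8, 2]], dtype='int64')
tar_mask = np.array([3, 3], dtype='int32')

out = prepare_input((src_ids, src_mask, tar_ids, tar_mask))
src, in_tar, label_tar, s_mask, t_mask, word_num = out
print(in_tar.shape)     # (2, 3): decoder input, last token dropped
print(label_tar.shape)  # (2, 3, 1): labels, first token dropped
print(word_num)         # 6: total target tokens, used to normalize ppl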


def get_data_iter(raw_data,
                  batch_size,
@@ -218,3 +278,97 @@ def get_data_iter(raw_data,
        src_ids, src_mask = to_pad_np(src_cache, source=True)
        tar_ids, tar_mask = to_pad_np(tar_cache)
        yield (src_ids, src_mask, tar_ids, tar_mask)


def get_reader(batch_size,
               src_lang,
               tar_lang,
               vocab_prefix,
               data_prefix,
               max_len=50,
               reader_mode='train',
               mode='train',
               enable_ce=False,
               cache_num=20):
    # reader_mode selects which data files are read; mode controls
    # shuffling and the cache size of the produced batches
    def get_data_reader():
        def get_batch_data():
            if reader_mode == 'train':
                raw_data = raw_train_data(
                    src_lang,
                    tar_lang,
                    vocab_prefix,
                    data_prefix,
                    max_sequence_len=max_len)
            elif reader_mode == 'valid':
                raw_data = raw_eval_data(src_lang, tar_lang, vocab_prefix,
                                         data_prefix)
            else:
                raw_data = raw_test_data(src_lang, tar_lang, vocab_prefix,
                                         data_prefix)
            src_data, tar_data = raw_data
            data_len = len(src_data)

            index = np.arange(data_len)
            if mode == "train" and not enable_ce:
                np.random.shuffle(index)

            def to_pad_np(data, source=False):
                cur_max_len = 0
                bs = min(batch_size, len(data))
                for ele in data:
                    if len(ele) > cur_max_len:
                        cur_max_len = len(ele)

                # 2 is the id of the </s> token used for padding
                ids = np.ones((bs, cur_max_len), dtype='int64') * 2
                mask = np.zeros((bs), dtype='int32')

                for i, ele in enumerate(data):
                    ids[i, :len(ele)] = ele
                    if not source:
                        mask[i] = len(ele) - 1
                    else:
                        mask[i] = len(ele)
                return ids, mask

            b_src = []
            nonlocal cache_num
            if mode != "train":
                cache_num = 1
            for j in range(data_len):
                if len(b_src) == batch_size * cache_num:
                    # the cache is full: sort it by source length so that
                    # sequences of similar length land in the same batch
                    # (skipped in infer mode to preserve sample order)
                    if mode == 'infer':
                        new_cache = b_src
                    else:
                        new_cache = sorted(b_src, key=lambda k: len(k[0]))

                    for i in range(cache_num):
                        batch_data = new_cache[i * batch_size:(i + 1) *
                                               batch_size]
                        src_cache = [w[0] for w in batch_data]
                        tar_cache = [w[1] for w in batch_data]
                        src_ids, src_mask = to_pad_np(src_cache, source=True)
                        tar_ids, tar_mask = to_pad_np(tar_cache)
                        yield prepare_input(
                            (src_ids, src_mask, tar_ids, tar_mask))

                    b_src = []

                b_src.append((src_data[index[j]], tar_data[index[j]]))

            # flush the tail; a partially filled cache is only emitted in
            # infer mode, otherwise the leftover samples are dropped
            if len(b_src) == batch_size * cache_num or mode == 'infer':
                if mode == 'infer':
                    new_cache = b_src
                else:
                    new_cache = sorted(b_src, key=lambda k: len(k[0]))

                for i in range(cache_num):
                    batch_end = min(len(new_cache), (i + 1) * batch_size)
                    batch_data = new_cache[i * batch_size:batch_end]
                    src_cache = [w[0] for w in batch_data]
                    tar_cache = [w[1] for w in batch_data]
                    src_ids, src_mask = to_pad_np(src_cache, source=True)
                    tar_ids, tar_mask = to_pad_np(tar_cache)
                    yield prepare_input(
                        (src_ids, src_mask, tar_ids, tar_mask))

        return get_batch_data

    data_reader = get_data_reader()
    return data_reader
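
Because get_reader returns a plain Python generator function, it can be smoke-tested without a DataLoader. A minimal sketch, reusing the hypothetical data/vocab.* and data/train.* prefixes from above (illustrative only):

batch_gen = get_reader(batch_size=32, src_lang="en", tar_lang="vi",
                       vocab_prefix="data/vocab", data_prefix="data/train",
                       max_len=50, reader_mode='train', mode='train')
src_ids, in_tar, label_tar, src_mask, tar_mask, word_num = next(batch_gen())
print(src_ids.shape, in_tar.shape, label_tar.shape)  # e.g. (32, L) shapes

The same callable is what set_batch_generator receives in train.py below.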

@@ -102,39 +102,31 @@ def main():
        src_lang = args.src_lang
        tar_lang = args.tar_lang
        print("begin to load data")
        train_data_iter = fluid.io.DataLoader.from_generator(
            capacity=32, use_double_buffer=True, iterable=True,
            return_list=True, use_multiprocess=True)
        valid_data_iter = fluid.io.DataLoader.from_generator(
            capacity=32, use_double_buffer=True, iterable=True,
            return_list=True, use_multiprocess=True)
        test_data_iter = fluid.io.DataLoader.from_generator(
            capacity=32, use_double_buffer=True, iterable=True,
            return_list=True, use_multiprocess=True)

        train_reader = reader.get_reader(
            batch_size, src_lang, tar_lang, vocab_prefix, train_data_prefix,
            max_len=args.max_len, reader_mode='train',
            enable_ce=args.enable_ce, cache_num=20)
        train_data_iter.set_batch_generator(train_reader, place)
        valid_reader = reader.get_reader(
            batch_size, src_lang, tar_lang, vocab_prefix, eval_data_prefix,
            reader_mode='valid', mode='eval', cache_num=20)
        valid_data_iter.set_batch_generator(valid_reader, place)
        test_reader = reader.get_reader(
            batch_size, src_lang, tar_lang, vocab_prefix, test_data_prefix,
            reader_mode='test', mode='eval', cache_num=20)
        test_data_iter.set_batch_generator(test_reader, place)
        print("finished load data")

        # get train epoch size
        def eval(eval_data_iter, epoch_id=0):
            model.eval()
            total_loss = 0.0
            word_count = 0.0
            for batch_id, batch in enumerate(eval_data_iter):
                # the first five entries feed the model, the last one is
                # the number of target tokens in the batch
                input_data_feed = batch[:-1]
                word_num = batch[-1]
                loss = model(input_data_feed)
                total_loss += loss * batch_size
                word_count += word_num
            ppl = np.exp(total_loss.numpy() / word_count.numpy())
            model.train()
            return ppl
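
The perplexity reported here is exp(summed cross-entropy / number of target tokens). A one-line sanity check with made-up numbers:

import numpy as np
# 160 nats of summed loss over 40 target tokens -> 4 nats per token
print(np.exp(160.0 / 40.0))   # 54.598..., the corresponding perplexity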

@@ -142,19 +134,14 @@ def main():
        for epoch_id in range(max_epoch):
            model.train()
            start_time = time.time()
            total_loss = 0
            word_count = 0.0
            batch_times = []
            for batch_id, batch in enumerate(train_data_iter):
                batch_start_time = time.time()
                # same split as in eval(): model feeds plus the token count
                input_data_feed = batch[:-1]
                word_num = batch[-1]
                word_count += word_num
                loss = model(input_data_feed)
                # print(loss.numpy()[0])
@@ -169,7 +156,7 @@ def main():
                if batch_id > 0 and batch_id % 100 == 0:
                    print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f"
                          % (epoch_id, batch_id, batch_time,
                             np.exp(total_loss.numpy() / word_count.numpy())))
                    total_loss = 0.0
                    word_count = 0.0
@@ -185,9 +172,9 @@ def main():
print("begin to save", dir_name) print("begin to save", dir_name)
paddle.fluid.save_dygraph(model.state_dict(), dir_name) paddle.fluid.save_dygraph(model.state_dict(), dir_name)
print("save finished") print("save finished")
dev_ppl = eval(valid_data) dev_ppl = eval(valid_data_iter)
print("dev ppl", dev_ppl) print("dev ppl", dev_ppl)
test_ppl = eval(test_data) test_ppl = eval(test_data_iter)
print("test ppl", test_ppl) print("test ppl", test_ppl)