From ffc686886467395d56b4b5a4999751452e21047b Mon Sep 17 00:00:00 2001
From: JepsonWong <2013000149@qq.com>
Date: Thu, 5 Mar 2020 17:34:38 +0000
Subject: [PATCH] add dataloader for seq2seq, test=develop

---
 dygraph/seq2seq/reader.py | 152 ++++++++++++++++++++++++++++++++++++++
 dygraph/seq2seq/train.py  |  51 +++++--------
 2 files changed, 171 insertions(+), 32 deletions(-)
 mode change 100755 => 100644 dygraph/seq2seq/reader.py
 mode change 100755 => 100644 dygraph/seq2seq/train.py

diff --git a/dygraph/seq2seq/reader.py b/dygraph/seq2seq/reader.py
old mode 100755
new mode 100644
index 4f275607..f07957d7
--- a/dygraph/seq2seq/reader.py
+++ b/dygraph/seq2seq/reader.py
@@ -22,6 +22,7 @@ import os
 import io
 import sys
 import numpy as np
+import paddle.fluid as fluid
 
 Py3 = sys.version_info[0] == 3
 
@@ -135,6 +136,53 @@ def raw_data(src_lang,
            (src_vocab, tar_vocab)
 
 
+def raw_train_data(src_lang, tar_lang, vocab_prefix, train_prefix, max_sequence_len=50):
+    src_vocab_file = vocab_prefix + "." + src_lang
+    tar_vocab_file = vocab_prefix + "." + tar_lang
+
+    src_train_file = train_prefix + "." + src_lang
+    tar_train_file = train_prefix + "." + tar_lang
+
+    src_vocab = _build_vocab(src_vocab_file)
+    tar_vocab = _build_vocab(tar_vocab_file)
+
+    train_src, train_tar = _para_file_to_ids(src_train_file, tar_train_file, \
+                                             src_vocab, tar_vocab)
+    train_src, train_tar = filter_len(
+        train_src, train_tar, max_sequence_len=max_sequence_len)
+
+    return (train_src, train_tar)
+
+def raw_eval_data(src_lang, tar_lang, vocab_prefix, eval_prefix, max_sequence_len=50):
+    src_vocab_file = vocab_prefix + "." + src_lang
+    tar_vocab_file = vocab_prefix + "." + tar_lang
+
+    src_eval_file = eval_prefix + "." + src_lang
+    tar_eval_file = eval_prefix + "." + tar_lang
+
+    src_vocab = _build_vocab(src_vocab_file)
+    tar_vocab = _build_vocab(tar_vocab_file)
+
+    eval_src, eval_tar = _para_file_to_ids(src_eval_file, tar_eval_file, \
+                                           src_vocab, tar_vocab)
+
+    return (eval_src, eval_tar)
+
+def raw_test_data(src_lang, tar_lang, vocab_prefix, test_prefix, max_sequence_len=50):
+    src_vocab_file = vocab_prefix + "." + src_lang
+    tar_vocab_file = vocab_prefix + "." + tar_lang
+
+    src_test_file = test_prefix + "." + src_lang
+    tar_test_file = test_prefix + "." + tar_lang
+
+    src_vocab = _build_vocab(src_vocab_file)
+    tar_vocab = _build_vocab(tar_vocab_file)
+
+    test_src, test_tar = _para_file_to_ids(src_test_file, tar_test_file, \
+                                           src_vocab, tar_vocab)
+
+    return (test_src, test_tar)
+
 def raw_mono_data(vocab_file, file_path):
 
     src_vocab = _build_vocab(vocab_file)
@@ -144,6 +192,17 @@ def raw_mono_data(vocab_file, file_path):
 
     return (test_src, test_tar)
 
+def prepare_input(batch):
+    src_ids, src_mask, tar_ids, tar_mask = batch
+    src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
+    in_tar = tar_ids[:, :-1]
+    label_tar = tar_ids[:, 1:]
+
+    in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
+    label_tar = label_tar.reshape(
+        (label_tar.shape[0], label_tar.shape[1], 1))
+    inputs = (src_ids, in_tar, label_tar, src_mask, tar_mask, np.sum(tar_mask))
+    return inputs
 
 def get_data_iter(raw_data,
                   batch_size,
@@ -218,3 +277,96 @@ def get_data_iter(raw_data,
         src_ids, src_mask = to_pad_np(src_cache, source=True)
         tar_ids, tar_mask = to_pad_np(tar_cache)
         yield (src_ids, src_mask, tar_ids, tar_mask)
+
+def get_reader(batch_size,
+               src_lang,
+               tar_lang,
+               vocab_prefix,
+               data_prefix,
+               max_len=50,
+               reader_mode='train',
+               mode='train',
+               enable_ce=False,
+               cache_num=20):
+    def get_data_reader():
+        def get_batch_data():
+            nonlocal cache_num
+            if reader_mode == 'train':
+                raw_data = raw_train_data(src_lang, tar_lang, vocab_prefix, data_prefix, max_sequence_len=max_len)
+            elif reader_mode == 'valid':
+                raw_data = raw_eval_data(src_lang, tar_lang, vocab_prefix, data_prefix)
+            else:
+                raw_data = raw_test_data(src_lang, tar_lang, vocab_prefix, data_prefix)
+
+            src_data, tar_data = raw_data
+
+            data_len = len(src_data)
+
+            index = np.arange(data_len)
+            if mode == "train" and not enable_ce:
+                np.random.shuffle(index)
+
+            def to_pad_np(data, source=False):
+                max_len = 0
+                bs = min(batch_size, len(data))
+                for ele in data:
+                    if len(ele) > max_len:
+                        max_len = len(ele)
+
+                ids = np.ones((bs, max_len), dtype='int64') * 2
+                mask = np.zeros((bs), dtype='int32')
+
+                for i, ele in enumerate(data):
+                    ids[i, :len(ele)] = ele
+                    if not source:
+                        mask[i] = len(ele) - 1
+                    else:
+                        mask[i] = len(ele)
+
+                return ids, mask
+
+            b_src = []
+
+            if mode != "train":
+                cache_num = 1
+            for j in range(data_len):
+                if len(b_src) == batch_size * cache_num:
+                    # a full cache is ready: sort it by source length (except
+                    # in infer mode) so similar-length examples share a batch
+                    if mode == 'infer':
+                        new_cache = b_src
+                    else:
+                        new_cache = sorted(b_src, key=lambda k: len(k[0]))
+
+                    for i in range(cache_num):
+                        batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
+                        src_cache = [w[0] for w in batch_data]
+                        tar_cache = [w[1] for w in batch_data]
+                        src_ids, src_mask = to_pad_np(src_cache, source=True)
+                        tar_ids, tar_mask = to_pad_np(tar_cache)
+                        yield prepare_input((src_ids, src_mask, tar_ids, tar_mask))
+
+                    b_src = []
+
+                b_src.append((src_data[index[j]], tar_data[index[j]]))
+            if len(b_src) == batch_size * cache_num or mode == 'infer':
+                # flush the tail: only an exactly-full cache, or any remainder in infer mode
+                if mode == 'infer':
+                    new_cache = b_src
+                else:
+                    new_cache = sorted(b_src, key=lambda k: len(k[0]))
+
+                for i in range(cache_num):
+                    batch_end = min(len(new_cache), (i + 1) * batch_size)
+                    batch_data = new_cache[i * batch_size:batch_end]
+                    src_cache = [w[0] for w in batch_data]
+                    tar_cache = [w[1] for w in batch_data]
+                    src_ids, src_mask = to_pad_np(src_cache, source=True)
+                    tar_ids, tar_mask = to_pad_np(tar_cache)
+                    yield prepare_input((src_ids, src_mask, tar_ids, tar_mask))
+
+        return get_batch_data
+
+    data_reader = get_data_reader()
+
+    return data_reader
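Review note: get_reader() returns a generator function, and every batch it yields is already the 6-tuple produced by prepare_input, i.e. (src_ids, in_tar, label_tar, src_mask, tar_mask, word_num), in the order the model and the DataLoader below expect. A minimal sketch for smoke-testing the reader on its own; the batch size, language pair, and data-path prefixes here are hypothetical placeholders, not part of this patch:

    import reader

    batch_generator = reader.get_reader(
        128, "en", "vi",
        "data/en-vi/vocab",   # hypothetical vocab prefix
        "data/en-vi/train",   # hypothetical data-file prefix
        max_len=50, reader_mode='train', mode='train', cache_num=20)

    for src_ids, in_tar, label_tar, src_mask, tar_mask, word_num in batch_generator():
        # src_ids and in_tar are int64 arrays of shape [batch, max_len_in_batch];
        # label_tar carries a trailing unit axis for the loss; the masks hold
        # per-sentence lengths; word_num is the summed target length
        print(src_ids.shape, label_tar.shape, word_num)
        break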
diff --git a/dygraph/seq2seq/train.py b/dygraph/seq2seq/train.py
old mode 100755
new mode 100644
index a25dccc8..0ae32f65
--- a/dygraph/seq2seq/train.py
+++ b/dygraph/seq2seq/train.py
@@ -102,39 +102,31 @@ def main():
     src_lang = args.src_lang
     tar_lang = args.tar_lang
     print("begin to load data")
-    raw_data = reader.raw_data(src_lang, tar_lang, vocab_prefix,
-                               train_data_prefix, eval_data_prefix,
-                               test_data_prefix, args.max_len)
+    train_data_iter = fluid.io.DataLoader.from_generator(capacity=32, use_double_buffer=True, iterable=True, return_list=True, use_multiprocess=True)
+    valid_data_iter = fluid.io.DataLoader.from_generator(capacity=32, use_double_buffer=True, iterable=True, return_list=True, use_multiprocess=True)
+    test_data_iter = fluid.io.DataLoader.from_generator(capacity=32, use_double_buffer=True, iterable=True, return_list=True, use_multiprocess=True)
+
+    train_reader = reader.get_reader(batch_size, src_lang, tar_lang, vocab_prefix, train_data_prefix, max_len=args.max_len, reader_mode='train', enable_ce=args.enable_ce, cache_num=20)
+    train_data_iter.set_batch_generator(train_reader, place)
+    valid_reader = reader.get_reader(batch_size, src_lang, tar_lang, vocab_prefix, eval_data_prefix, reader_mode='valid', mode='eval', cache_num=20)
+    valid_data_iter.set_batch_generator(valid_reader, place)
+    test_reader = reader.get_reader(batch_size, src_lang, tar_lang, vocab_prefix, test_data_prefix, reader_mode='test', mode='eval', cache_num=20)
+    test_data_iter.set_batch_generator(test_reader, place)
     print("finished load data")
-    train_data, valid_data, test_data, _ = raw_data
-
-    def prepare_input(batch, epoch_id=0):
-        src_ids, src_mask, tar_ids, tar_mask = batch
-        res = {}
-        src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
-        in_tar = tar_ids[:, :-1]
-        label_tar = tar_ids[:, 1:]
-
-        in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
-        label_tar = label_tar.reshape(
-            (label_tar.shape[0], label_tar.shape[1], 1))
-        inputs = [src_ids, in_tar, label_tar, src_mask, tar_mask]
-        return inputs, np.sum(tar_mask)
 
     # get train epoch size
-    def eval(data, epoch_id=0):
+    def eval(eval_data_iter, epoch_id=0):
         model.eval()
-        eval_data_iter = reader.get_data_iter(data, batch_size, mode='eval')
         total_loss = 0.0
         word_count = 0.0
         for batch_id, batch in enumerate(eval_data_iter):
-            input_data_feed, word_num = prepare_input(
-                batch, epoch_id)
+            src_ids, in_tar, label_tar, src_mask, tar_mask, word_num = batch
+            input_data_feed = [src_ids, in_tar, label_tar, src_mask, tar_mask]
             loss = model(input_data_feed)
             total_loss += loss * batch_size
             word_count += word_num
-        ppl = np.exp(total_loss.numpy() / word_count)
+        ppl = np.exp(total_loss.numpy() / word_count.numpy())
         model.train()
         return ppl
 
@@ -142,19 +134,14 @@ def main():
     for epoch_id in range(max_epoch):
         model.train()
         start_time = time.time()
-        if args.enable_ce:
-            train_data_iter = reader.get_data_iter(
-                train_data, batch_size, enable_ce=True)
-        else:
-            train_data_iter = reader.get_data_iter(train_data, batch_size)
 
         total_loss = 0
         word_count = 0.0
         batch_times = []
         for batch_id, batch in enumerate(train_data_iter):
             batch_start_time = time.time()
-            input_data_feed, word_num = prepare_input(
-                batch, epoch_id=epoch_id)
+            src_ids, in_tar, label_tar, src_mask, tar_mask, word_num = batch
+            input_data_feed = [src_ids, in_tar, label_tar, src_mask, tar_mask]
             word_count += word_num
             loss = model(input_data_feed)
             # print(loss.numpy()[0])
@@ -169,7 +156,7 @@ def main():
             if batch_id > 0 and batch_id % 100 == 0:
                 print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" %
                       (epoch_id, batch_id, batch_time,
-                       np.exp(total_loss.numpy() / word_count)))
+                       np.exp(total_loss.numpy() / word_count.numpy())))
                 total_loss = 0.0
                 word_count = 0.0
@@ -185,9 +172,9 @@ def main():
         print("begin to save", dir_name)
         paddle.fluid.save_dygraph(model.state_dict(), dir_name)
         print("save finished")
-        dev_ppl = eval(valid_data)
+        dev_ppl = eval(valid_data_iter)
         print("dev ppl", dev_ppl)
-        test_ppl = eval(test_data)
+        test_ppl = eval(test_data_iter)
         print("test ppl", test_ppl)
-- 
GitLab
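For reference, the train.py hunks repeat one wiring pattern per data split: create a DataLoader, then bind the batch generator to it. A condensed sketch of that pattern for a single split, under the same hypothetical batch size and paths as above, with place chosen locally:

    import paddle.fluid as fluid
    import reader

    place = fluid.CPUPlace()  # or fluid.CUDAPlace(0) when training on GPU

    loader = fluid.io.DataLoader.from_generator(
        capacity=32, use_double_buffer=True, iterable=True,
        return_list=True, use_multiprocess=True)
    loader.set_batch_generator(
        reader.get_reader(128, "en", "vi",
                          "data/en-vi/vocab",   # hypothetical path
                          "data/en-vi/train",   # hypothetical path
                          max_len=50, reader_mode='train', cache_num=20),
        place)

    # iterated directly, exactly as the patched training loop does
    for batch_id, batch in enumerate(loader):
        src_ids, in_tar, label_tar, src_mask, tar_mask, word_num = batch
        break

Because return_list=True, each element of the batch arrives as a Paddle tensor rather than a numpy array, which is why the perplexity computations above switch from word_count to word_count.numpy().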