From 2bb216b7d34d11047f7f6c9546f56d07222616d2 Mon Sep 17 00:00:00 2001 From: guosheng Date: Sat, 4 Apr 2020 18:43:22 +0800 Subject: [PATCH] Update seq2seq --- seq2seq/reader.py | 510 +++++++++++++++++++++++++--------------- seq2seq/seq2seq_attn.py | 13 +- seq2seq/seq2seq_base.py | 4 +- seq2seq/train.py | 86 ++++--- 4 files changed, 379 insertions(+), 234 deletions(-) diff --git a/seq2seq/reader.py b/seq2seq/reader.py index e562b6e..145acec 100644 --- a/seq2seq/reader.py +++ b/seq2seq/reader.py @@ -16,203 +16,333 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import os +import glob import io -import sys import numpy as np - -Py3 = sys.version_info[0] == 3 - -UNK_ID = 0 - - -def _read_words(filename): - data = [] - with io.open(filename, "r", encoding='utf-8') as f: - if Py3: - return f.read().replace("\n", "").split() +import itertools +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.fluid.io import BatchSampler, DataLoader, Dataset + + +def prepare_train_input(insts, bos_id, eos_id, pad_id): + src, src_length = pad_batch_data( + [inst[0] for inst in insts], pad_id=pad_id) + trg, trg_length = pad_batch_data( + [[bos_id] + inst[1] + [eos_id] for inst in insts], pad_id=pad_id) + trg_length = trg_length - 1 + return src, src_length, trg[:, :-1], trg_length, trg[:, 1:, np.newaxis] + + +def pad_batch_data(insts, pad_id): + """ + Pad the instances to the max sequence length in batch, and generate the + corresponding position data and attention bias. + """ + inst_lens = np.array([len(inst) for inst in insts], dtype="int64") + max_len = np.max(inst_lens) + inst_data = np.array( + [inst + [pad_id] * (max_len - len(inst)) for inst in insts], + dtype="int64") + return inst_data, inst_lens + + +class SortType(object): + GLOBAL = 'global' + POOL = 'pool' + NONE = "none" + + +class Converter(object): + def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end): + self._vocab = vocab + self._beg = beg + self._end = end + self._unk = unk + self._delimiter = delimiter + self._add_beg = add_beg + self._add_end = add_end + + def __call__(self, sentence): + return ([self._beg] if self._add_beg else []) + [ + self._vocab.get(w, self._unk) + for w in sentence.split(self._delimiter) + ] + ([self._end] if self._add_end else []) + + +class ComposedConverter(object): + def __init__(self, converters): + self._converters = converters + + def __call__(self, fields): + return [ + converter(field) + for field, converter in zip(fields, self._converters) + ] + + +class SentenceBatchCreator(object): + def __init__(self, batch_size): + self.batch = [] + self._batch_size = batch_size + + def append(self, info): + self.batch.append(info) + if len(self.batch) == self._batch_size: + tmp = self.batch + self.batch = [] + return tmp + + +class TokenBatchCreator(object): + def __init__(self, batch_size): + self.batch = [] + self.max_len = -1 + self._batch_size = batch_size + + def append(self, info): + cur_len = info.max_len + max_len = max(self.max_len, cur_len) + if max_len * (len(self.batch) + 1) > self._batch_size: + result = self.batch + self.batch = [info] + self.max_len = cur_len + return result else: - return f.read().decode("utf-8").replace(u"\n", u"").split() - - -def read_all_line(filenam): - data = [] - with io.open(filename, "r", encoding='utf-8') as f: - for line in f.readlines(): - data.append(line.strip()) - - -def _build_vocab(filename): - - vocab_dict = {} - ids = 0 - with io.open(filename, "r", encoding='utf-8') as f: - for line in f.readlines(): - vocab_dict[line.strip()] = ids - ids += 1 - - print("vocab word num", ids) - - return vocab_dict - - -def _para_file_to_ids(src_file, tar_file, src_vocab, tar_vocab): - - src_data = [] - with io.open(src_file, "r", encoding='utf-8') as f_src: - for line in f_src.readlines(): - arra = line.strip().split() - ids = [src_vocab[w] if w in src_vocab else UNK_ID for w in arra] - ids = ids - - src_data.append(ids) - - tar_data = [] - with io.open(tar_file, "r", encoding='utf-8') as f_tar: - for line in f_tar.readlines(): - arra = line.strip().split() - ids = [tar_vocab[w] if w in tar_vocab else UNK_ID for w in arra] - - ids = [1] + ids + [2] - - tar_data.append(ids) - - return src_data, tar_data - - -def filter_len(src, tar, max_sequence_len=50): - new_src = [] - new_tar = [] - - for id1, id2 in zip(src, tar): - if len(id1) > max_sequence_len: - id1 = id1[:max_sequence_len] - if len(id2) > max_sequence_len + 2: - id2 = id2[:max_sequence_len + 2] - - new_src.append(id1) - new_tar.append(id2) - - return new_src, new_tar - + self.max_len = max_len + self.batch.append(info) -def raw_data(src_lang, - tar_lang, - vocab_prefix, - train_prefix, - eval_prefix, - test_prefix, - max_sequence_len=50): - src_vocab_file = vocab_prefix + "." + src_lang - tar_vocab_file = vocab_prefix + "." + tar_lang +class SampleInfo(object): + def __init__(self, i, max_len, min_len): + self.i = i + self.min_len = min_len + self.max_len = max_len - src_train_file = train_prefix + "." + src_lang - tar_train_file = train_prefix + "." + tar_lang - src_eval_file = eval_prefix + "." + src_lang - tar_eval_file = eval_prefix + "." + tar_lang +class MinMaxFilter(object): + def __init__(self, max_len, min_len, underlying_creator): + self._min_len = min_len + self._max_len = max_len + self._creator = underlying_creator - src_test_file = test_prefix + "." + src_lang - tar_test_file = test_prefix + "." + tar_lang - - src_vocab = _build_vocab(src_vocab_file) - tar_vocab = _build_vocab(tar_vocab_file) - - train_src, train_tar = _para_file_to_ids( src_train_file, tar_train_file, \ - src_vocab, tar_vocab ) - train_src, train_tar = filter_len( - train_src, train_tar, max_sequence_len=max_sequence_len) - eval_src, eval_tar = _para_file_to_ids( src_eval_file, tar_eval_file, \ - src_vocab, tar_vocab ) - - test_src, test_tar = _para_file_to_ids( src_test_file, tar_test_file, \ - src_vocab, tar_vocab ) - - return ( train_src, train_tar), (eval_src, eval_tar), (test_src, test_tar),\ - (src_vocab, tar_vocab) - - -def raw_mono_data(vocab_file, file_path): - - src_vocab = _build_vocab(vocab_file) - - test_src, test_tar = _para_file_to_ids( file_path, file_path, \ - src_vocab, src_vocab ) - - return (test_src, test_tar) - - -def get_data_iter(raw_data, - batch_size, - mode='train', - enable_ce=False, - cache_num=20): - - src_data, tar_data = raw_data - - data_len = len(src_data) - - index = np.arange(data_len) - if mode == "train" and not enable_ce: - np.random.shuffle(index) - - def to_pad_np(data, source=False): - max_len = 0 - bs = min(batch_size, len(data)) - for ele in data: - if len(ele) > max_len: - max_len = len(ele) - - ids = np.ones((bs, max_len), dtype='int64') * 2 - mask = np.zeros((bs), dtype='int32') - - for i, ele in enumerate(data): - ids[i, :len(ele)] = ele - if not source: - mask[i] = len(ele) - 1 - else: - mask[i] = len(ele) - - return ids, mask - - b_src = [] - - if mode != "train": - cache_num = 1 - for j in range(data_len): - if len(b_src) == batch_size * cache_num: - # build batch size - - # sort - if mode == 'infer': - new_cache = b_src + def append(self, info): + if info.max_len > self._max_len or info.min_len < self._min_len: + return + else: + return self._creator.append(info) + + @property + def batch(self): + return self._creator.batch + + +class Seq2SeqDataset(Dataset): + def __init__(self, + src_vocab_fpath, + trg_vocab_fpath, + fpattern, + field_delimiter="\t", + token_delimiter=" ", + start_mark="", + end_mark="", + unk_mark="", + only_src=False, + trg_fpattern=None): + # convert str to bytes, and use byte data + # field_delimiter = field_delimiter.encode("utf8") + # token_delimiter = token_delimiter.encode("utf8") + # start_mark = start_mark.encode("utf8") + # end_mark = end_mark.encode("utf8") + # unk_mark = unk_mark.encode("utf8") + self._src_vocab = self.load_dict(src_vocab_fpath) + self._trg_vocab = self.load_dict(trg_vocab_fpath) + self._bos_idx = self._src_vocab[start_mark] + self._eos_idx = self._src_vocab[end_mark] + self._unk_idx = self._src_vocab[unk_mark] + self._only_src = only_src + self._field_delimiter = field_delimiter + self._token_delimiter = token_delimiter + self.load_src_trg_ids(fpattern, trg_fpattern) + + def load_src_trg_ids(self, fpattern, trg_fpattern=None): + src_converter = Converter( + vocab=self._src_vocab, + beg=self._bos_idx, + end=self._eos_idx, + unk=self._unk_idx, + delimiter=self._token_delimiter, + add_beg=False, + add_end=False) + + trg_converter = Converter( + vocab=self._trg_vocab, + beg=self._bos_idx, + end=self._eos_idx, + unk=self._unk_idx, + delimiter=self._token_delimiter, + add_beg=False, + add_end=False) + + converters = ComposedConverter([src_converter, trg_converter]) + + self._src_seq_ids = [] + self._trg_seq_ids = [] + self._sample_infos = [] + + slots = [self._src_seq_ids, self._trg_seq_ids] + lens = [] + for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)): + lens = [] + for field, slot in zip(converters(line), slots): + slot.append(field) + lens.append(len(field)) + # self._sample_infos.append(SampleInfo(i, max(lens), min(lens))) + self._sample_infos.append(SampleInfo(i, lens[0], lens[0])) + + def _load_lines(self, fpattern, trg_fpattern=None): + fpaths = glob.glob(fpattern) + fpaths = sorted(fpaths) # TODO: Add custum sort + assert len(fpaths) > 0, "no matching file to the provided data path" + + if trg_fpattern is None: + for fpath in fpaths: + # with io.open(fpath, "rb") as f: + with io.open(fpath, "r", encoding="utf8") as f: + for line in f: + fields = line.strip("\n").split(self._field_delimiter) + yield fields + else: + # separated source and target language data files + # assume we can get aligned data by sort the two language files + # TODO: Need more rigorous check + trg_fpaths = glob.glob(trg_fpattern) + trg_fpaths = sorted(trg_fpaths) + assert len(fpaths) == len( + trg_fpaths + ), "the number of source language data files must equal \ + with that of source language" + + for fpath, trg_fpath in zip(fpaths, trg_fpaths): + # with io.open(fpath, "rb") as f: + # with io.open(trg_fpath, "rb") as trg_f: + with io.open(fpath, "r", encoding="utf8") as f: + with io.open(trg_fpath, "r", encoding="utf8") as trg_f: + for line in zip(f, trg_f): + fields = [field.strip("\n") for field in line] + yield fields + + @staticmethod + def load_dict(dict_path, reverse=False): + word_dict = {} + # with io.open(dict_path, "rb") as fdict: + with io.open(dict_path, "r", encoding="utf8") as fdict: + for idx, line in enumerate(fdict): + if reverse: + word_dict[idx] = line.strip("\n") + else: + word_dict[line.strip("\n")] = idx + return word_dict + + def get_vocab_summary(self): + return len(self._src_vocab), len( + self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx + + def __getitem__(self, idx): + return (self._src_seq_ids[idx], self._trg_seq_ids[idx] + ) if self._trg_seq_ids else self._src_seq_ids[idx] + + def __len__(self): + return len(self._sample_infos) + + +class Seq2SeqBatchSampler(BatchSampler): + def __init__(self, + dataset, + batch_size, + pool_size=10000, + sort_type=SortType.NONE, + min_length=0, + max_length=100, + shuffle=False, + shuffle_batch=False, + use_token_batch=False, + clip_last_batch=False, + seed=None): + for arg, value in locals().items(): + if arg != "self": + setattr(self, "_" + arg, value) + self._random = np.random + self._random.seed(seed) + # for multi-devices + self._nranks = ParallelEnv().nranks + self._local_rank = ParallelEnv().local_rank + self._device_id = ParallelEnv().dev_id + + def __iter__(self): + # global sort or global shuffle + if self._sort_type == SortType.GLOBAL: + infos = sorted( + self._dataset._sample_infos, key=lambda x: x.max_len) + else: + if self._shuffle: + infos = self._dataset._sample_infos + self._random.shuffle(infos) else: - new_cache = sorted(b_src, key=lambda k: len(k[0])) - - for i in range(cache_num): - batch_data = new_cache[i * batch_size:(i + 1) * batch_size] - src_cache = [w[0] for w in batch_data] - tar_cache = [w[1] for w in batch_data] - src_ids, src_mask = to_pad_np(src_cache, source=True) - tar_ids, tar_mask = to_pad_np(tar_cache) - yield (src_ids, src_mask, tar_ids, tar_mask) - - b_src = [] - - b_src.append((src_data[index[j]], tar_data[index[j]])) - if len(b_src) == batch_size * cache_num or mode == 'infer': - if mode == 'infer': - new_cache = b_src + infos = self._dataset._sample_infos + + if self._sort_type == SortType.POOL: + reverse = True + for i in range(0, len(infos), self._pool_size): + # to avoid placing short next to long sentences + reverse = not reverse + infos[i:i + self._pool_size] = sorted( + infos[i:i + self._pool_size], + key=lambda x: x.max_len, + reverse=reverse) + + batches = [] + batch_creator = TokenBatchCreator( + self. + _batch_size) if self._use_token_batch else SentenceBatchCreator( + self._batch_size * self._nranks) + batch_creator = MinMaxFilter(self._max_length, self._min_length, + batch_creator) + + for info in infos: + batch = batch_creator.append(info) + if batch is not None: + batches.append(batch) + + if not self._clip_last_batch and len(batch_creator.batch) != 0: + batches.append(batch_creator.batch) + + if self._shuffle_batch: + self._random.shuffle(batches) + + if not self._use_token_batch: + # when producing batches according to sequence number, to confirm + # neighbor batches which would be feed and run parallel have similar + # length (thus similar computational cost) after shuffle, we as take + # them as a whole when shuffling and split here + batches = [[ + batch[self._batch_size * i:self._batch_size * (i + 1)] + for i in range(self._nranks) + ] for batch in batches] + batches = list(itertools.chain.from_iterable(batches)) + + # for multi-device + for batch_id, batch in enumerate(batches): + if batch_id % self._nranks == self._local_rank: + batch_indices = [info.i for info in batch] + yield batch_indices + if self._local_rank > len(batches) % self._nranks: + yield batch_indices + + def __len__(self): + if not self._use_token_batch: + batch_number = ( + len(self._dataset) + self._batch_size * self._nranks - 1) // ( + self._batch_size * self._nranks) else: - new_cache = sorted(b_src, key=lambda k: len(k[0])) - - for i in range(cache_num): - batch_end = min(len(new_cache), (i + 1) * batch_size) - batch_data = new_cache[i * batch_size:batch_end] - src_cache = [w[0] for w in batch_data] - tar_cache = [w[1] for w in batch_data] - src_ids, src_mask = to_pad_np(src_cache, source=True) - tar_ids, tar_mask = to_pad_np(tar_cache) - yield (src_ids, src_mask, tar_ids, tar_mask) + batch_number = 100 + return batch_number diff --git a/seq2seq/seq2seq_attn.py b/seq2seq/seq2seq_attn.py index b71018e..599d25e 100644 --- a/seq2seq/seq2seq_attn.py +++ b/seq2seq/seq2seq_attn.py @@ -41,9 +41,10 @@ class AttentionLayer(Layer): bias_attr=bias) def forward(self, hidden, encoder_output, encoder_padding_mask): - query = self.input_proj(hidden) + # query = self.input_proj(hidden) + encoder_output = self.input_proj(encoder_output) attn_scores = layers.matmul( - layers.unsqueeze(query, [1]), encoder_output, transpose_y=True) + layers.unsqueeze(hidden, [1]), encoder_output, transpose_y=True) if encoder_padding_mask is not None: attn_scores = layers.elementwise_add(attn_scores, encoder_padding_mask) @@ -73,7 +74,9 @@ class DecoderCell(RNNCell): BasicLSTMCell( input_size=input_size + hidden_size if i == 0 else hidden_size, - hidden_size=hidden_size))) + hidden_size=hidden_size, + param_attr=ParamAttr(initializer=UniformInitializer( + low=-init_scale, high=init_scale))))) self.attention_layer = AttentionLayer(hidden_size) def forward(self, @@ -107,8 +110,8 @@ class Decoder(Layer): size=[vocab_size, embed_dim], param_attr=ParamAttr(initializer=UniformInitializer( low=-init_scale, high=init_scale))) - self.lstm_attention = RNN(DecoderCell(num_layers, embed_dim, - hidden_size, init_scale), + self.lstm_attention = RNN(DecoderCell( + num_layers, embed_dim, hidden_size, dropout_prob, init_scale), is_reverse=False, time_major=False) self.output_layer = Linear( diff --git a/seq2seq/seq2seq_base.py b/seq2seq/seq2seq_base.py index f56a873..ae2cb4b 100644 --- a/seq2seq/seq2seq_base.py +++ b/seq2seq/seq2seq_base.py @@ -86,7 +86,7 @@ class Encoder(Layer): param_attr=ParamAttr(initializer=UniformInitializer( low=-init_scale, high=init_scale))) self.stack_lstm = RNN(EncoderCell(num_layers, embed_dim, hidden_size, - init_scale), + dropout_prob, init_scale), is_reverse=False, time_major=False) @@ -114,7 +114,7 @@ class Decoder(Layer): param_attr=ParamAttr(initializer=UniformInitializer( low=-init_scale, high=init_scale))) self.stack_lstm = RNN(DecoderCell(num_layers, embed_dim, hidden_size, - init_scale), + dropout_prob, init_scale), is_reverse=False, time_major=False) self.output_layer = Linear( diff --git a/seq2seq/train.py b/seq2seq/train.py index 3ca8ae6..70f9315 100644 --- a/seq2seq/train.py +++ b/seq2seq/train.py @@ -17,8 +17,7 @@ import os import six import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -import time -import contextlib +import random from functools import partial import numpy as np @@ -34,16 +33,17 @@ from seq2seq_base import BaseModel, CrossEntropyCriterion from seq2seq_attn import AttentionModel from model import Input, set_device from callbacks import ProgBarLogger -from metrics import Metric - - -class PPL(Metric): - pass +from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_train_input def do_train(args): device = set_device("gpu" if args.use_gpu else "cpu") - fluid.enable_dygraph(device) #if args.eager_run else None + fluid.enable_dygraph(device) if args.eager_run else None + + if args.enable_ce: + fluid.default_main_program().random_seed = 102 + fluid.default_startup_program().random_seed = 102 + args.shuffle = False # define model inputs = [ @@ -58,6 +58,45 @@ def do_train(args): ] labels = [Input([None, None, 1], "int64", name="label"), ] + # def dataloader + data_loaders = [None, None] + data_prefixes = [args.train_data_prefix, args.eval_data_prefix + ] if args.eval_data_prefix else [args.train_data_prefix] + for i, data_prefix in enumerate(data_prefixes): + dataset = Seq2SeqDataset( + fpattern=data_prefix + "." + args.src_lang, + trg_fpattern=data_prefix + "." + args.tar_lang, + src_vocab_fpath=args.vocab_prefix + "." + args.src_lang, + trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang, + token_delimiter=None, + start_mark="", + end_mark="", + unk_mark="") + (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id, + unk_id) = dataset.get_vocab_summary() + batch_sampler = Seq2SeqBatchSampler( + dataset=dataset, + use_token_batch=False, + batch_size=args.batch_size, + pool_size=args.batch_size * 20, + sort_type=SortType.POOL, + shuffle=args.shuffle) + data_loader = DataLoader( + dataset=dataset, + batch_sampler=batch_sampler, + places=device, + feed_list=None if fluid.in_dygraph_mode() else + [x.forward() for x in inputs + labels], + collate_fn=partial( + prepare_train_input, + bos_id=bos_id, + eos_id=eos_id, + pad_id=eos_id), + num_workers=0, + return_list=True) + data_loaders[i] = data_loader + train_loader, eval_loader = data_loaders + model = AttentionModel(args.src_vocab_size, args.tar_vocab_size, args.hidden_size, args.hidden_size, args.num_layers, args.dropout) @@ -69,39 +108,12 @@ def do_train(args): CrossEntropyCriterion(), inputs=inputs, labels=labels) - - batch_size = 32 - src_seq_len = 10 - trg_seq_len = 12 - iter_num = 10 - - def random_generator(): - for i in range(iter_num): - src = np.random.randint(2, args.src_vocab_size, - (batch_size, src_seq_len)).astype("int64") - src_length = np.random.randint(1, src_seq_len, - (batch_size, )).astype("int64") - trg = np.random.randint(2, args.tar_vocab_size, - (batch_size, trg_seq_len)).astype("int64") - trg_length = np.random.randint(1, trg_seq_len, - (batch_size, )).astype("int64") - label = np.random.randint( - 1, trg_seq_len, (batch_size, trg_seq_len, 1)).astype("int64") - yield src, src_length, trg, trg_length, label - - model.fit(train_data=random_generator, log_freq=1) - exit(0) - - data_loaders = [None, None] - data_files = [args.training_file, args.validation_file - ] if args.validation_file else [args.training_file] - train_loader, eval_loader = data_loaders - model.fit(train_data=train_loader, - eval_data=None, + eval_data=eval_loader, epochs=1, eval_freq=1, save_freq=1, + log_freq=1, verbose=2) -- GitLab