Commit 2bb216b7 authored by guosheng

Update seq2seq

Parent f91528a9
@@ -16,203 +16,333 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
-import os
-import io
-import sys
-import numpy as np
-
-Py3 = sys.version_info[0] == 3
-
-UNK_ID = 0
-
-
-def _read_words(filename):
-    data = []
-    with io.open(filename, "r", encoding='utf-8') as f:
-        if Py3:
-            return f.read().replace("\n", "<eos>").split()
-        else:
-            return f.read().decode("utf-8").replace(u"\n", u"<eos>").split()
-
-
-def read_all_line(filenam):
-    data = []
-    with io.open(filename, "r", encoding='utf-8') as f:
-        for line in f.readlines():
-            data.append(line.strip())
-
-
-def _build_vocab(filename):
-    vocab_dict = {}
-    ids = 0
-    with io.open(filename, "r", encoding='utf-8') as f:
-        for line in f.readlines():
-            vocab_dict[line.strip()] = ids
-            ids += 1
-    print("vocab word num", ids)
-    return vocab_dict
-
-
-def _para_file_to_ids(src_file, tar_file, src_vocab, tar_vocab):
-    src_data = []
-    with io.open(src_file, "r", encoding='utf-8') as f_src:
-        for line in f_src.readlines():
-            arra = line.strip().split()
-            ids = [src_vocab[w] if w in src_vocab else UNK_ID for w in arra]
-            ids = ids
-            src_data.append(ids)
-    tar_data = []
-    with io.open(tar_file, "r", encoding='utf-8') as f_tar:
-        for line in f_tar.readlines():
-            arra = line.strip().split()
-            ids = [tar_vocab[w] if w in tar_vocab else UNK_ID for w in arra]
-            ids = [1] + ids + [2]
-            tar_data.append(ids)
-    return src_data, tar_data
-
-
-def filter_len(src, tar, max_sequence_len=50):
-    new_src = []
-    new_tar = []
-    for id1, id2 in zip(src, tar):
-        if len(id1) > max_sequence_len:
-            id1 = id1[:max_sequence_len]
-        if len(id2) > max_sequence_len + 2:
-            id2 = id2[:max_sequence_len + 2]
-        new_src.append(id1)
-        new_tar.append(id2)
-    return new_src, new_tar
-
-
-def raw_data(src_lang,
-             tar_lang,
-             vocab_prefix,
-             train_prefix,
-             eval_prefix,
-             test_prefix,
-             max_sequence_len=50):
-
-    src_vocab_file = vocab_prefix + "." + src_lang
-    tar_vocab_file = vocab_prefix + "." + tar_lang
-
-    src_train_file = train_prefix + "." + src_lang
-    tar_train_file = train_prefix + "." + tar_lang
-
-    src_eval_file = eval_prefix + "." + src_lang
-    tar_eval_file = eval_prefix + "." + tar_lang
-
-    src_test_file = test_prefix + "." + src_lang
-    tar_test_file = test_prefix + "." + tar_lang
-
-    src_vocab = _build_vocab(src_vocab_file)
-    tar_vocab = _build_vocab(tar_vocab_file)
-
-    train_src, train_tar = _para_file_to_ids( src_train_file, tar_train_file, \
-                                              src_vocab, tar_vocab )
-    train_src, train_tar = filter_len(
-        train_src, train_tar, max_sequence_len=max_sequence_len)
-    eval_src, eval_tar = _para_file_to_ids( src_eval_file, tar_eval_file, \
-                                            src_vocab, tar_vocab )
-
-    test_src, test_tar = _para_file_to_ids( src_test_file, tar_test_file, \
-                                            src_vocab, tar_vocab )
-
-    return ( train_src, train_tar), (eval_src, eval_tar), (test_src, test_tar),\
-        (src_vocab, tar_vocab)
-
-
-def raw_mono_data(vocab_file, file_path):
-
-    src_vocab = _build_vocab(vocab_file)
-
-    test_src, test_tar = _para_file_to_ids( file_path, file_path, \
-                                            src_vocab, src_vocab )
-
-    return (test_src, test_tar)
-
-
-def get_data_iter(raw_data,
-                  batch_size,
-                  mode='train',
-                  enable_ce=False,
-                  cache_num=20):
-
-    src_data, tar_data = raw_data
-
-    data_len = len(src_data)
-
-    index = np.arange(data_len)
-    if mode == "train" and not enable_ce:
-        np.random.shuffle(index)
-
-    def to_pad_np(data, source=False):
-        max_len = 0
-        bs = min(batch_size, len(data))
-        for ele in data:
-            if len(ele) > max_len:
-                max_len = len(ele)
-
-        ids = np.ones((bs, max_len), dtype='int64') * 2
-        mask = np.zeros((bs), dtype='int32')
-
-        for i, ele in enumerate(data):
-            ids[i, :len(ele)] = ele
-            if not source:
-                mask[i] = len(ele) - 1
-            else:
-                mask[i] = len(ele)
-
-        return ids, mask
-
-    b_src = []
-
-    if mode != "train":
-        cache_num = 1
-    for j in range(data_len):
-        if len(b_src) == batch_size * cache_num:
-            # build batch size
-
-            # sort
-            if mode == 'infer':
-                new_cache = b_src
-            else:
-                new_cache = sorted(b_src, key=lambda k: len(k[0]))
-
-            for i in range(cache_num):
-                batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
-                src_cache = [w[0] for w in batch_data]
-                tar_cache = [w[1] for w in batch_data]
-                src_ids, src_mask = to_pad_np(src_cache, source=True)
-                tar_ids, tar_mask = to_pad_np(tar_cache)
-                yield (src_ids, src_mask, tar_ids, tar_mask)
-
-            b_src = []
-
-        b_src.append((src_data[index[j]], tar_data[index[j]]))
-    if len(b_src) == batch_size * cache_num or mode == 'infer':
-        if mode == 'infer':
-            new_cache = b_src
-        else:
-            new_cache = sorted(b_src, key=lambda k: len(k[0]))
-
-        for i in range(cache_num):
-            batch_end = min(len(new_cache), (i + 1) * batch_size)
-            batch_data = new_cache[i * batch_size:batch_end]
-            src_cache = [w[0] for w in batch_data]
-            tar_cache = [w[1] for w in batch_data]
-            src_ids, src_mask = to_pad_np(src_cache, source=True)
-            tar_ids, tar_mask = to_pad_np(tar_cache)
-            yield (src_ids, src_mask, tar_ids, tar_mask)
+import glob
+import io
+import numpy as np
+import itertools
+
+from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.fluid.io import BatchSampler, DataLoader, Dataset
+
+
+def prepare_train_input(insts, bos_id, eos_id, pad_id):
+    src, src_length = pad_batch_data(
+        [inst[0] for inst in insts], pad_id=pad_id)
+    trg, trg_length = pad_batch_data(
+        [[bos_id] + inst[1] + [eos_id] for inst in insts], pad_id=pad_id)
+    trg_length = trg_length - 1
+    return src, src_length, trg[:, :-1], trg_length, trg[:, 1:, np.newaxis]
+
+
+def pad_batch_data(insts, pad_id):
+    """
+    Pad the instances to the max sequence length in batch, and generate the
+    corresponding position data and attention bias.
+    """
+    inst_lens = np.array([len(inst) for inst in insts], dtype="int64")
+    max_len = np.max(inst_lens)
+    inst_data = np.array(
+        [inst + [pad_id] * (max_len - len(inst)) for inst in insts],
+        dtype="int64")
+    return inst_data, inst_lens
+
+
+class SortType(object):
+    GLOBAL = 'global'
+    POOL = 'pool'
+    NONE = "none"
+
+
+class Converter(object):
+    def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
+        self._vocab = vocab
+        self._beg = beg
+        self._end = end
+        self._unk = unk
+        self._delimiter = delimiter
+        self._add_beg = add_beg
+        self._add_end = add_end
+
+    def __call__(self, sentence):
+        return ([self._beg] if self._add_beg else []) + [
+            self._vocab.get(w, self._unk)
+            for w in sentence.split(self._delimiter)
+        ] + ([self._end] if self._add_end else [])
+
+
+class ComposedConverter(object):
+    def __init__(self, converters):
+        self._converters = converters
+
+    def __call__(self, fields):
+        return [
+            converter(field)
+            for field, converter in zip(fields, self._converters)
+        ]
+
+
+class SentenceBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self._batch_size = batch_size
+
+    def append(self, info):
+        self.batch.append(info)
+        if len(self.batch) == self._batch_size:
+            tmp = self.batch
+            self.batch = []
+            return tmp
+
+
+class TokenBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self.max_len = -1
+        self._batch_size = batch_size
+
+    def append(self, info):
+        cur_len = info.max_len
+        max_len = max(self.max_len, cur_len)
+        if max_len * (len(self.batch) + 1) > self._batch_size:
+            result = self.batch
+            self.batch = [info]
+            self.max_len = cur_len
+            return result
+        else:
+            self.max_len = max_len
+            self.batch.append(info)
+
+
+class SampleInfo(object):
+    def __init__(self, i, max_len, min_len):
+        self.i = i
+        self.min_len = min_len
+        self.max_len = max_len
+
+
+class MinMaxFilter(object):
+    def __init__(self, max_len, min_len, underlying_creator):
+        self._min_len = min_len
+        self._max_len = max_len
+        self._creator = underlying_creator
+
+    def append(self, info):
+        if info.max_len > self._max_len or info.min_len < self._min_len:
+            return
+        else:
+            return self._creator.append(info)
+
+    @property
+    def batch(self):
+        return self._creator.batch
+
+
+class Seq2SeqDataset(Dataset):
+    def __init__(self,
+                 src_vocab_fpath,
+                 trg_vocab_fpath,
+                 fpattern,
+                 field_delimiter="\t",
+                 token_delimiter=" ",
+                 start_mark="<s>",
+                 end_mark="<e>",
+                 unk_mark="<unk>",
+                 only_src=False,
+                 trg_fpattern=None):
+        # convert str to bytes, and use byte data
+        # field_delimiter = field_delimiter.encode("utf8")
+        # token_delimiter = token_delimiter.encode("utf8")
+        # start_mark = start_mark.encode("utf8")
+        # end_mark = end_mark.encode("utf8")
+        # unk_mark = unk_mark.encode("utf8")
+        self._src_vocab = self.load_dict(src_vocab_fpath)
+        self._trg_vocab = self.load_dict(trg_vocab_fpath)
+        self._bos_idx = self._src_vocab[start_mark]
+        self._eos_idx = self._src_vocab[end_mark]
+        self._unk_idx = self._src_vocab[unk_mark]
+        self._only_src = only_src
+        self._field_delimiter = field_delimiter
+        self._token_delimiter = token_delimiter
+        self.load_src_trg_ids(fpattern, trg_fpattern)
+
+    def load_src_trg_ids(self, fpattern, trg_fpattern=None):
+        src_converter = Converter(
+            vocab=self._src_vocab,
+            beg=self._bos_idx,
+            end=self._eos_idx,
+            unk=self._unk_idx,
+            delimiter=self._token_delimiter,
+            add_beg=False,
+            add_end=False)
+
+        trg_converter = Converter(
+            vocab=self._trg_vocab,
+            beg=self._bos_idx,
+            end=self._eos_idx,
+            unk=self._unk_idx,
+            delimiter=self._token_delimiter,
+            add_beg=False,
+            add_end=False)
+
+        converters = ComposedConverter([src_converter, trg_converter])
+
+        self._src_seq_ids = []
+        self._trg_seq_ids = []
+        self._sample_infos = []
+
+        slots = [self._src_seq_ids, self._trg_seq_ids]
+        lens = []
+        for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
+            lens = []
+            for field, slot in zip(converters(line), slots):
+                slot.append(field)
+                lens.append(len(field))
+            # self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
+            self._sample_infos.append(SampleInfo(i, lens[0], lens[0]))
+
+    def _load_lines(self, fpattern, trg_fpattern=None):
+        fpaths = glob.glob(fpattern)
+        fpaths = sorted(fpaths)  # TODO: Add custum sort
+        assert len(fpaths) > 0, "no matching file to the provided data path"
+
+        if trg_fpattern is None:
+            for fpath in fpaths:
+                # with io.open(fpath, "rb") as f:
+                with io.open(fpath, "r", encoding="utf8") as f:
+                    for line in f:
+                        fields = line.strip("\n").split(self._field_delimiter)
+                        yield fields
+        else:
+            # separated source and target language data files
+            # assume we can get aligned data by sort the two language files
+            # TODO: Need more rigorous check
+            trg_fpaths = glob.glob(trg_fpattern)
+            trg_fpaths = sorted(trg_fpaths)
+            assert len(fpaths) == len(
+                trg_fpaths
+            ), "the number of source language data files must equal \
+                with that of source language"
+
+            for fpath, trg_fpath in zip(fpaths, trg_fpaths):
+                # with io.open(fpath, "rb") as f:
+                # with io.open(trg_fpath, "rb") as trg_f:
+                with io.open(fpath, "r", encoding="utf8") as f:
+                    with io.open(trg_fpath, "r", encoding="utf8") as trg_f:
+                        for line in zip(f, trg_f):
+                            fields = [field.strip("\n") for field in line]
+                            yield fields
+
+    @staticmethod
+    def load_dict(dict_path, reverse=False):
+        word_dict = {}
+        # with io.open(dict_path, "rb") as fdict:
+        with io.open(dict_path, "r", encoding="utf8") as fdict:
+            for idx, line in enumerate(fdict):
+                if reverse:
+                    word_dict[idx] = line.strip("\n")
+                else:
+                    word_dict[line.strip("\n")] = idx
+        return word_dict
+
+    def get_vocab_summary(self):
+        return len(self._src_vocab), len(
+            self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
+
+    def __getitem__(self, idx):
+        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
+                ) if self._trg_seq_ids else self._src_seq_ids[idx]
+
+    def __len__(self):
+        return len(self._sample_infos)
+
+
+class Seq2SeqBatchSampler(BatchSampler):
+    def __init__(self,
+                 dataset,
+                 batch_size,
+                 pool_size=10000,
+                 sort_type=SortType.NONE,
+                 min_length=0,
+                 max_length=100,
+                 shuffle=False,
+                 shuffle_batch=False,
+                 use_token_batch=False,
+                 clip_last_batch=False,
+                 seed=None):
+        for arg, value in locals().items():
+            if arg != "self":
+                setattr(self, "_" + arg, value)
+        self._random = np.random
+        self._random.seed(seed)
+        # for multi-devices
+        self._nranks = ParallelEnv().nranks
+        self._local_rank = ParallelEnv().local_rank
+        self._device_id = ParallelEnv().dev_id
+
+    def __iter__(self):
+        # global sort or global shuffle
+        if self._sort_type == SortType.GLOBAL:
+            infos = sorted(
+                self._dataset._sample_infos, key=lambda x: x.max_len)
+        else:
+            if self._shuffle:
+                infos = self._dataset._sample_infos
+                self._random.shuffle(infos)
+            else:
+                infos = self._dataset._sample_infos
+
+            if self._sort_type == SortType.POOL:
+                reverse = True
+                for i in range(0, len(infos), self._pool_size):
+                    # to avoid placing short next to long sentences
+                    reverse = not reverse
+                    infos[i:i + self._pool_size] = sorted(
+                        infos[i:i + self._pool_size],
+                        key=lambda x: x.max_len,
+                        reverse=reverse)
+
+        batches = []
+        batch_creator = TokenBatchCreator(
+            self.
+            _batch_size) if self._use_token_batch else SentenceBatchCreator(
+                self._batch_size * self._nranks)
+        batch_creator = MinMaxFilter(self._max_length, self._min_length,
+                                     batch_creator)
+
+        for info in infos:
+            batch = batch_creator.append(info)
+            if batch is not None:
+                batches.append(batch)
+
+        if not self._clip_last_batch and len(batch_creator.batch) != 0:
+            batches.append(batch_creator.batch)
+
+        if self._shuffle_batch:
+            self._random.shuffle(batches)
+
+        if not self._use_token_batch:
+            # when producing batches according to sequence number, to confirm
+            # neighbor batches which would be feed and run parallel have similar
+            # length (thus similar computational cost) after shuffle, we as take
+            # them as a whole when shuffling and split here
+            batches = [[
+                batch[self._batch_size * i:self._batch_size * (i + 1)]
+                for i in range(self._nranks)
+            ] for batch in batches]
+            batches = list(itertools.chain.from_iterable(batches))
+
+        # for multi-device
+        for batch_id, batch in enumerate(batches):
+            if batch_id % self._nranks == self._local_rank:
+                batch_indices = [info.i for info in batch]
+                yield batch_indices
+        if self._local_rank > len(batches) % self._nranks:
+            yield batch_indices
+
+    def __len__(self):
+        if not self._use_token_batch:
+            batch_number = (
+                len(self._dataset) + self._batch_size * self._nranks - 1) // (
+                    self._batch_size * self._nranks)
+        else:
+            batch_number = 100
+        return batch_number
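
For reference, the new classes added above are meant to be composed the way the updated training script further down wires them: build a Seq2SeqDataset, wrap it in a Seq2SeqBatchSampler, and let a DataLoader collate the sampled index batches with prepare_train_input. A minimal sketch under assumed placeholder file paths, batch size and device (none of these values are part of this commit):

# Usage sketch of the new reader API; file patterns, batch size and the
# place are placeholder assumptions, the wiring mirrors train.py below.
from functools import partial

from paddle import fluid
from paddle.fluid.io import DataLoader

fluid.enable_dygraph(fluid.CPUPlace())   # or the CUDAPlace actually in use

dataset = Seq2SeqDataset(
    fpattern="data/train.en",            # assumed source-side file pattern
    trg_fpattern="data/train.vi",        # assumed target-side file pattern
    src_vocab_fpath="data/vocab.en",     # assumed vocab files
    trg_vocab_fpath="data/vocab.vi",
    token_delimiter=None,                # split tokens on any whitespace
    start_mark="<s>",
    end_mark="</s>",
    unk_mark="<unk>")
src_vocab_size, trg_vocab_size, bos_id, eos_id, unk_id = \
    dataset.get_vocab_summary()

batch_sampler = Seq2SeqBatchSampler(
    dataset=dataset,
    batch_size=128,                      # assumed sentence-count batch size
    pool_size=128 * 20,
    sort_type=SortType.POOL,
    shuffle=True)

data_loader = DataLoader(
    dataset=dataset,
    batch_sampler=batch_sampler,
    places=fluid.CPUPlace(),
    collate_fn=partial(
        prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=eos_id),
    num_workers=0,
    return_list=True)
# data_loader is then handed to model.fit(train_data=data_loader, ...)
# exactly as in the training-script hunk further down.
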
@@ -41,9 +41,10 @@ class AttentionLayer(Layer):
             bias_attr=bias)
 
     def forward(self, hidden, encoder_output, encoder_padding_mask):
-        query = self.input_proj(hidden)
+        # query = self.input_proj(hidden)
+        encoder_output = self.input_proj(encoder_output)
         attn_scores = layers.matmul(
-            layers.unsqueeze(query, [1]), encoder_output, transpose_y=True)
+            layers.unsqueeze(hidden, [1]), encoder_output, transpose_y=True)
         if encoder_padding_mask is not None:
             attn_scores = layers.elementwise_add(attn_scores,
                                                  encoder_padding_mask)
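
The hunk above moves the linear projection from the decoder state onto the encoder outputs, so the score becomes hidden · (W · encoder_output)^T rather than (W · hidden) · encoder_output^T, with the padding bias still added before softmax. A shape-only NumPy illustration of just that scoring step; all sizes and values are made up and W stands in for the input_proj weights (this is not the Paddle code):

import numpy as np

batch, src_len, hidden_size = 2, 5, 8
hidden = np.random.rand(batch, hidden_size).astype("float32")
encoder_output = np.random.rand(batch, src_len, hidden_size).astype("float32")
W = np.random.rand(hidden_size, hidden_size).astype("float32")

projected = encoder_output @ W                       # [batch, src_len, hidden]
scores = np.einsum("bh,bsh->bs", hidden, projected)  # [batch, src_len]

# padding positions carry a large negative bias so softmax ignores them
mask = np.zeros((batch, src_len), dtype="float32")
mask[:, -1] = -1e9                # pretend the last source position is padding
scores = scores + mask
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)             # attention weights
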
@@ -73,7 +74,9 @@ class DecoderCell(RNNCell):
                 BasicLSTMCell(
                     input_size=input_size + hidden_size
                     if i == 0 else hidden_size,
-                    hidden_size=hidden_size)))
+                    hidden_size=hidden_size,
+                    param_attr=ParamAttr(initializer=UniformInitializer(
+                        low=-init_scale, high=init_scale)))))
         self.attention_layer = AttentionLayer(hidden_size)
 
     def forward(self,
@@ -107,8 +110,8 @@ class Decoder(Layer):
             size=[vocab_size, embed_dim],
             param_attr=ParamAttr(initializer=UniformInitializer(
                 low=-init_scale, high=init_scale)))
-        self.lstm_attention = RNN(DecoderCell(num_layers, embed_dim,
-                                              hidden_size, init_scale),
+        self.lstm_attention = RNN(DecoderCell(
+            num_layers, embed_dim, hidden_size, dropout_prob, init_scale),
                                   is_reverse=False,
                                   time_major=False)
         self.output_layer = Linear(
...
@@ -86,7 +86,7 @@ class Encoder(Layer):
             param_attr=ParamAttr(initializer=UniformInitializer(
                 low=-init_scale, high=init_scale)))
         self.stack_lstm = RNN(EncoderCell(num_layers, embed_dim, hidden_size,
-                                          init_scale),
+                                          dropout_prob, init_scale),
                               is_reverse=False,
                               time_major=False)
 
@@ -114,7 +114,7 @@ class Decoder(Layer):
             param_attr=ParamAttr(initializer=UniformInitializer(
                 low=-init_scale, high=init_scale)))
         self.stack_lstm = RNN(DecoderCell(num_layers, embed_dim, hidden_size,
-                                          init_scale),
+                                          dropout_prob, init_scale),
                               is_reverse=False,
                               time_major=False)
         self.output_layer = Linear(
...
@@ -17,8 +17,7 @@ import os
 import six
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-import time
-import contextlib
+import random
 from functools import partial
 
 import numpy as np
@@ -34,16 +33,17 @@ from seq2seq_base import BaseModel, CrossEntropyCriterion
 from seq2seq_attn import AttentionModel
 from model import Input, set_device
 from callbacks import ProgBarLogger
-from metrics import Metric
-
-
-class PPL(Metric):
-    pass
+from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_train_input
 
 
 def do_train(args):
     device = set_device("gpu" if args.use_gpu else "cpu")
-    fluid.enable_dygraph(device)  #if args.eager_run else None
+    fluid.enable_dygraph(device) if args.eager_run else None
+
+    if args.enable_ce:
+        fluid.default_main_program().random_seed = 102
+        fluid.default_startup_program().random_seed = 102
+        args.shuffle = False
 
     # define model
     inputs = [
@@ -58,6 +58,45 @@ def do_train(args):
     ]
     labels = [Input([None, None, 1], "int64", name="label"), ]
 
+    # def dataloader
+    data_loaders = [None, None]
+    data_prefixes = [args.train_data_prefix, args.eval_data_prefix
+                     ] if args.eval_data_prefix else [args.train_data_prefix]
+    for i, data_prefix in enumerate(data_prefixes):
+        dataset = Seq2SeqDataset(
+            fpattern=data_prefix + "." + args.src_lang,
+            trg_fpattern=data_prefix + "." + args.tar_lang,
+            src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
+            trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang,
+            token_delimiter=None,
+            start_mark="<s>",
+            end_mark="</s>",
+            unk_mark="<unk>")
+        (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
+         unk_id) = dataset.get_vocab_summary()
+        batch_sampler = Seq2SeqBatchSampler(
+            dataset=dataset,
+            use_token_batch=False,
+            batch_size=args.batch_size,
+            pool_size=args.batch_size * 20,
+            sort_type=SortType.POOL,
+            shuffle=args.shuffle)
+        data_loader = DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            places=device,
+            feed_list=None if fluid.in_dygraph_mode() else
+            [x.forward() for x in inputs + labels],
+            collate_fn=partial(
+                prepare_train_input,
+                bos_id=bos_id,
+                eos_id=eos_id,
+                pad_id=eos_id),
+            num_workers=0,
+            return_list=True)
+        data_loaders[i] = data_loader
+    train_loader, eval_loader = data_loaders
+
     model = AttentionModel(args.src_vocab_size, args.tar_vocab_size,
                            args.hidden_size, args.hidden_size, args.num_layers,
                            args.dropout)
@@ -69,39 +108,12 @@ def do_train(args):
         CrossEntropyCriterion(),
         inputs=inputs,
         labels=labels)
 
-    batch_size = 32
-    src_seq_len = 10
-    trg_seq_len = 12
-    iter_num = 10
-
-    def random_generator():
-        for i in range(iter_num):
-            src = np.random.randint(2, args.src_vocab_size,
-                                    (batch_size, src_seq_len)).astype("int64")
-            src_length = np.random.randint(1, src_seq_len,
-                                           (batch_size, )).astype("int64")
-            trg = np.random.randint(2, args.tar_vocab_size,
-                                    (batch_size, trg_seq_len)).astype("int64")
-            trg_length = np.random.randint(1, trg_seq_len,
-                                           (batch_size, )).astype("int64")
-            label = np.random.randint(
-                1, trg_seq_len, (batch_size, trg_seq_len, 1)).astype("int64")
-            yield src, src_length, trg, trg_length, label
-
-    model.fit(train_data=random_generator, log_freq=1)
-    exit(0)
-
-    data_loaders = [None, None]
-    data_files = [args.training_file, args.validation_file
-                  ] if args.validation_file else [args.training_file]
-
-    train_loader, eval_loader = data_loaders
-
     model.fit(train_data=train_loader,
-              eval_data=None,
+              eval_data=eval_loader,
               epochs=1,
               eval_freq=1,
               save_freq=1,
-              log_freq=1,
               verbose=2)
...
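
To make the collate_fn wired up above concrete, here is a tiny worked example of prepare_train_input with made-up token ids; the comments show the arrays it returns under these assumptions:

# Toy check of the collate_fn used above; token ids are made up.
from reader import prepare_train_input

insts = [([5, 6, 7], [8, 9]),        # (src_ids, trg_ids) for sample 0
         ([5, 6], [8, 9, 10])]       # sample 1
bos_id, eos_id, pad_id = 1, 2, 2

src, src_length, dec_input, dec_length, label = prepare_train_input(
    insts, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id)

# src        -> [[5 6 7] [5 6 2]]        sources padded with pad_id
# src_length -> [3 2]
# dec_input  -> [[1 8 9 2] [1 8 9 10]]   <s> + trg + </s>, last step dropped
# dec_length -> [3 4]                    decoder steps per sample
# label      -> shape (2, 4, 1)          next-token ids (trg shifted left)
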