提交 68dfe864 编写于 作者: G guosheng

Update Transformer

上级 64a7e1a0
......@@ -16,18 +16,68 @@ import glob
import six
import os
import tarfile
import itertools
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.io import BatchSampler, DataLoader
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset
class TokenBatchSampler(BatchSampler):
def __init__(self):
pass
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
Put all padded data needed by training into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)
def __iter(self):
pass
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
lbl_word = lbl_word.reshape(-1, 1)
lbl_weight = lbl_weight.reshape(-1, 1)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
]
return data_inputs
def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
"""
Put all padded data needed by beam search decoder into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
]
return data_inputs
def pad_batch_data(insts,
......@@ -88,60 +138,206 @@ def pad_batch_data(insts,
return return_list if len(return_list) > 1 else return_list[0]
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
Put all padded data needed by training into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)
class Seq2SeqDataset(Dataset):
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
tar_fname=None,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
only_src=False):
# convert str to bytes, and use byte data
field_delimiter = field_delimiter.encode("utf8")
token_delimiter = token_delimiter.encode("utf8")
start_mark = start_mark.encode("utf8")
end_mark = end_mark.encode("utf8")
unk_mark = unk_mark.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._only_src = only_src
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, tar_fname)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
def load_src_trg_ids(self, fpattern, tar_fname):
converters = [
Converter(
vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
]
if not self._only_src:
converters.append(
Converter(
vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
lbl_word = lbl_word.reshape(-1, 1)
lbl_weight = lbl_weight.reshape(-1, 1)
converters = ComposedConverter(converters)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
]
self._src_seq_ids = []
self._trg_seq_ids = None if self._only_src else []
self._sample_infos = []
return data_inputs
for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
src_trg_ids = converters(line)
self._src_seq_ids.append(src_trg_ids[0])
lens = [len(src_trg_ids[0])]
if not self._only_src:
self._trg_seq_ids.append(src_trg_ids[1])
lens.append(len(src_trg_ids[1]))
self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
def _load_lines(self, fpattern, tar_fname):
fpaths = glob.glob(fpattern)
assert len(fpaths) > 0, "no matching file to the provided data path"
def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
"""
Put all padded data needed by beam search decoder into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
]
return data_inputs
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
else:
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
for line in f:
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(b"\n")
else:
word_dict[line.strip(b"\n")] = idx
return word_dict
def get_vocab_summary(self):
return len(self._src_vocab), len(
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
) if not self._only_src else self._src_seq_ids[idx]
def __len__(self):
return len(self._sample_infos)
class Seq2SeqBatchSampler(BatchSampler):
def __init__(self,
dataset,
batch_size,
pool_size,
sort_type=SortType.GLOBAL,
min_length=0,
max_length=100,
shuffle=True,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
seed=0):
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)
# for multi-devices
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self._device_id = ParallelEnv().dev_id
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self.dataset._sample_infos, key=lambda x: x.max_len)
else:
if self._shuffle:
infos = self.dataset._sample_infos
self._random.shuffle(infos)
else:
infos = self.dataset._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
batches = []
batch_creator = TokenBatchCreator(
self.
_batch_size) if self._use_token_batch else SentenceBatchCreator(
self._batch_size * self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
if not self._use_token_batch:
# when producing batches according to sequence number, to confirm
# neighbor batches which would be feed and run parallel have similar
# length (thus similar computational cost) after shuffle, we as take
# them as a whole when shuffling and split here
batches = [[
batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks)
] for batch in batches]
batches = itertools.chain.from_iterable(batches)
# for multi-device
for batch_id, batch in enumerate(batches):
if batch_id % self._nranks == self._local_rank:
batch_indices = [info.i for info in batch]
yield batch_indices
if self._local_rank > len(batches) % self._nranks:
yield batch_indices
def __len__(self):
pass
@property
def dev_id(self):
return self._dev_id
class SortType(object):
......
......@@ -24,6 +24,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.io import DataLoader
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
......@@ -38,6 +39,7 @@ from callbacks import ProgBarLogger
class LoggerCallback(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
super(LoggerCallback, self).__init__(log_freq, verbose)
# TODO: wrap these override function to simplify
self.loss_normalizer = loss_normalizer
def on_train_begin(self, logs=None):
......@@ -60,148 +62,111 @@ class LoggerCallback(ProgBarLogger):
def do_train(args):
init_context('dynamic' if FLAGS.dynamic else 'static')
trainer_count = 1 #get_nranks()
@contextlib.contextmanager
def null_guard():
yield
guard = fluid.dygraph.guard() if args.eager_run else null_guard()
# define the data generator
processor = reader.DataProcessor(
# init_context('dynamic' if FLAGS.dynamic else 'static')
# set seed for CE
random_seed = eval(str(args.random_seed))
if random_seed is not None:
fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed
# define model
inputs = [
Input(
[None, None], "int64", name="src_word"), Input(
[None, None], "int64", name="src_pos"), Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"), Input(
[None, None], "int64", name="trg_word"), Input(
[None, None], "int64", name="trg_pos"), Input(
[None, args.n_head, None, None],
"float32",
name="trg_slf_attn_bias"), Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias")
]
labels = [
Input(
[None, 1], "int64", name="label"),
Input(
[None, 1], "float32", name="weight"),
]
dataset = reader.Seq2SeqDataset(
fpattern=args.training_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = reader.Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=trainer_count,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="train")
if trainer_count > 1: # for multi-process gpu training
batch_generator = fluid.contrib.reader.distributed_batch_reader(
batch_generator)
if args.validation_file:
val_processor = reader.DataProcessor(
fpattern=args.validation_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=trainer_count,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
val_batch_generator = val_processor.data_generator(phase="train")
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
with guard:
# set seed for CE
random_seed = eval(str(args.random_seed))
if random_seed is not None:
fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed
# define data loader
train_loader = batch_generator
if args.validation_file:
val_loader = val_batch_generator
# define model
inputs = [
Input(
[None, None], "int64", name="src_word"),
Input(
[None, None], "int64", name="src_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input(
[None, None], "int64", name="trg_word"),
Input(
[None, None], "int64", name="trg_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_slf_attn_bias"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
labels = [
Input(
[None, 1], "int64", name="label"),
Input(
[None, 1], "float32", name="weight"),
]
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
args.d_inner_hid, args.prepostprocess_dropout,
args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
args.postprocess_cmd, args.weight_sharing, args.bos_idx,
args.eos_idx)
transformer.prepare(
fluid.optimizer.Adam(
learning_rate=fluid.layers.noam_decay(
args.d_model, args.warmup_steps), # args.learning_rate),
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps),
inputs=inputs,
labels=labels)
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
transformer.load(
os.path.join(args.init_from_checkpoint, "transformer"))
## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
transformer.load(
os.path.join(args.init_from_pretrain_model, "transformer"),
reset_optimizer=True)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
transformer.fit(train_loader=train_loader,
eval_loader=val_loader,
epochs=1,
eval_freq=1,
save_freq=1,
verbose=2,
callbacks=[
LoggerCallback(
log_freq=args.print_step,
loss_normalizer=loss_normalizer)
])
max_length=args.max_length)
train_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=None,
feed_list=[x.forward() for x in inputs + labels],
num_workers=0,
return_list=True)
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
args.weight_sharing, args.bos_idx, args.eos_idx)
transformer.prepare(
fluid.optimizer.Adam(
learning_rate=fluid.layers.noam_decay(
args.d_model, args.warmup_steps), # args.learning_rate),
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps),
inputs=inputs,
labels=labels)
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
transformer.load(
os.path.join(args.init_from_checkpoint, "transformer"))
## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
transformer.load(
os.path.join(args.init_from_pretrain_model, "transformer"),
reset_optimizer=True)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
transformer.fit(train_loader=train_loader,
eval_loader=None,
epochs=1,
eval_freq=1,
save_freq=1,
verbose=2,
callbacks=[
LoggerCallback(
log_freq=args.print_step,
loss_normalizer=loss_normalizer)
])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册