提交 57365421 编写于 作者: G guosheng

Update Transformer

上级 0b93f490
...@@ -17,20 +17,20 @@ import os ...@@ -17,20 +17,20 @@ import os
import six import six
import sys import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time from functools import partial
import contextlib
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.io import DataLoader
from paddle.fluid.layers.utils import flatten
from utils.configure import PDConfig from utils.configure import PDConfig
from utils.check import check_gpu, check_version from utils.check import check_gpu, check_version
# include task-specific libs from model import Input, set_device
import reader from reader import prepare_infer_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import InferTransformer, position_encoding_init from transformer import InferTransformer, position_encoding_init
from model import Input
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
...@@ -51,98 +51,86 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, ...@@ -51,98 +51,86 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
def do_predict(args): def do_predict(args):
@contextlib.contextmanager device = set_device("gpu" if args.use_cuda else "cpu")
def null_guard(): fluid.enable_dygraph(device) if args.eager_run else None
yield
inputs = [
guard = fluid.dygraph.guard() if args.eager_run else null_guard() Input([None, None], "int64", name="src_word"),
Input([None, None], "int64", name="src_pos"),
# define the data generator Input([None, args.n_head, None, None],
processor = reader.DataProcessor( "float32",
fpattern=args.predict_file, name="src_slf_attn_bias"),
src_vocab_fpath=args.src_vocab_fpath, Input([None, args.n_head, None, None],
trg_vocab_fpath=args.trg_vocab_fpath, "float32",
token_delimiter=args.token_delimiter, name="trg_src_attn_bias"),
use_token_batch=False, ]
batch_size=args.batch_size,
device_count=1, # define data
pool_size=args.pool_size, dataset = Seq2SeqDataset(fpattern=args.predict_file,
sort_type=reader.SortType.NONE, src_vocab_fpath=args.src_vocab_fpath,
shuffle=False, trg_vocab_fpath=args.trg_vocab_fpath,
shuffle_batch=False, token_delimiter=args.token_delimiter,
start_mark=args.special_token[0], start_mark=args.special_token[0],
end_mark=args.special_token[1], end_mark=args.special_token[1],
unk_mark=args.special_token[2], unk_mark=args.special_token[2])
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="predict")
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \ args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary() args.unk_idx = dataset.get_vocab_summary()
trg_idx2word = reader.DataProcessor.load_dict( trg_idx2word = Seq2SeqDataset.load_dict(dict_path=args.trg_vocab_fpath,
dict_path=args.trg_vocab_fpath, reverse=True) reverse=True)
batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
with guard: use_token_batch=False,
# define data loader batch_size=args.batch_size,
test_loader = batch_generator max_length=args.max_length)
data_loader = DataLoader(dataset=dataset,
# define model batch_sampler=batch_sampler,
inputs = [ places=device,
Input( feed_list=[x.forward() for x in inputs],
[None, None], "int64", name="src_word"), collate_fn=partial(prepare_infer_input,
Input( src_pad_idx=args.eos_idx,
[None, None], "int64", name="src_pos"), n_head=args.n_head),
Input( num_workers=0,
[None, args.n_head, None, None], return_list=True)
"float32",
name="src_slf_attn_bias"), # define model
Input( transformer = InferTransformer(args.src_vocab_size,
[None, args.n_head, None, None], args.trg_vocab_size,
"float32", args.max_length + 1,
name="trg_src_attn_bias"), args.n_layer,
] args.n_head,
transformer = InferTransformer( args.d_key,
args.src_vocab_size, args.d_value,
args.trg_vocab_size, args.d_model,
args.max_length + 1, args.d_inner_hid,
args.n_layer, args.prepostprocess_dropout,
args.n_head, args.attention_dropout,
args.d_key, args.relu_dropout,
args.d_value, args.preprocess_cmd,
args.d_model, args.postprocess_cmd,
args.d_inner_hid, args.weight_sharing,
args.prepostprocess_dropout, args.bos_idx,
args.attention_dropout, args.eos_idx,
args.relu_dropout, beam_size=args.beam_size,
args.preprocess_cmd, max_out_len=args.max_out_len)
args.postprocess_cmd, transformer.prepare(inputs=inputs)
args.weight_sharing,
args.bos_idx, # load the trained model
args.eos_idx, assert args.init_from_params, (
beam_size=args.beam_size, "Please set init_from_params to load the infer model.")
max_out_len=args.max_out_len) transformer.load(os.path.join(args.init_from_params, "transformer"))
transformer.prepare(inputs=inputs)
# TODO: use model.predict when support variant length
# load the trained model f = open(args.output_file, "wb")
assert args.init_from_params, ( for data in data_loader():
"Please set init_from_params to load the infer model.") finished_seq = transformer.test(inputs=flatten(data))[0]
transformer.load(os.path.join(args.init_from_params, "transformer")) finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
f = open(args.output_file, "wb") for beam_idx, beam in enumerate(ins):
for input_data in test_loader(): if beam_idx >= args.n_best: break
(src_word, src_pos, src_slf_attn_bias, trg_word, id_list = post_process_seq(beam, args.bos_idx,
trg_src_attn_bias) = input_data args.eos_idx)
finished_seq = transformer.test(inputs=( word_list = [trg_idx2word[id] for id in id_list]
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias))[0] sequence = b" ".join(word_list) + b"\n"
finished_seq = np.transpose(finished_seq, [0, 2, 1]) f.write(sequence)
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
if beam_idx >= args.n_best: break
id_list = post_process_seq(beam, args.bos_idx,
args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
sequence = b" ".join(word_list) + b"\n"
f.write(sequence)
break
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -60,22 +60,19 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): ...@@ -60,22 +60,19 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
return data_inputs return data_inputs
def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head): def prepare_infer_input(insts, src_pad_idx, n_head):
""" """
Put all padded data needed by beam search decoder into a list. Put all padded data needed by beam search decoder into a list.
""" """
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32") [1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1)
src_word = src_word.reshape(-1, src_max_len) src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len) src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [ data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
] ]
return data_inputs return data_inputs
...@@ -343,11 +340,11 @@ class Seq2SeqBatchSampler(BatchSampler): ...@@ -343,11 +340,11 @@ class Seq2SeqBatchSampler(BatchSampler):
def __init__(self, def __init__(self,
dataset, dataset,
batch_size, batch_size,
pool_size, pool_size=10000,
sort_type=SortType.GLOBAL, sort_type=SortType.NONE,
min_length=0, min_length=0,
max_length=100, max_length=100,
shuffle=True, shuffle=False,
shuffle_batch=False, shuffle_batch=False,
use_token_batch=False, use_token_batch=False,
clip_last_batch=False, clip_last_batch=False,
...@@ -412,7 +409,7 @@ class Seq2SeqBatchSampler(BatchSampler): ...@@ -412,7 +409,7 @@ class Seq2SeqBatchSampler(BatchSampler):
batch[self._batch_size * i:self._batch_size * (i + 1)] batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks) for i in range(self._nranks)
] for batch in batches] ] for batch in batches]
batches = itertools.chain.from_iterable(batches) batches = list(itertools.chain.from_iterable(batches))
# for multi-device # for multi-device
for batch_id, batch in enumerate(batches): for batch_id, batch in enumerate(batches):
......
...@@ -17,8 +17,6 @@ import os ...@@ -17,8 +17,6 @@ import os
import six import six
import sys import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import contextlib
from functools import partial from functools import partial
import numpy as np import numpy as np
...@@ -30,11 +28,10 @@ from paddle.fluid.io import DataLoader ...@@ -30,11 +28,10 @@ from paddle.fluid.io import DataLoader
from utils.configure import PDConfig from utils.configure import PDConfig
from utils.check import check_gpu, check_version from utils.check import check_gpu, check_version
# include task-specific libs
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
from model import Input, set_device from model import Input, set_device
from callbacks import ProgBarLogger from callbacks import ProgBarLogger
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
class LoggerCallback(ProgBarLogger): class LoggerCallback(ProgBarLogger):
...@@ -72,7 +69,7 @@ def do_train(args): ...@@ -72,7 +69,7 @@ def do_train(args):
fluid.default_main_program().random_seed = random_seed fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed fluid.default_startup_program().random_seed = random_seed
# define model # define inputs
inputs = [ inputs = [
Input([None, None], "int64", name="src_word"), Input([None, None], "int64", name="src_word"),
Input([None, None], "int64", name="src_pos"), Input([None, None], "int64", name="src_pos"),
...@@ -95,35 +92,42 @@ def do_train(args): ...@@ -95,35 +92,42 @@ def do_train(args):
[None, 1], "float32", name="weight"), [None, 1], "float32", name="weight"),
] ]
dataset = Seq2SeqDataset(fpattern=args.training_file, # def dataloader
src_vocab_fpath=args.src_vocab_fpath, data_loaders = [None, None]
trg_vocab_fpath=args.trg_vocab_fpath, data_files = [args.training_file, args.validation_file
token_delimiter=args.token_delimiter, ] if args.validation_file else [args.training_file]
start_mark=args.special_token[0], for i, data_file in enumerate(data_files):
end_mark=args.special_token[1], dataset = Seq2SeqDataset(fpattern=data_file,
unk_mark=args.special_token[2]) src_vocab_fpath=args.src_vocab_fpath,
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \ trg_vocab_fpath=args.trg_vocab_fpath,
args.unk_idx = dataset.get_vocab_summary() token_delimiter=args.token_delimiter,
batch_sampler = Seq2SeqBatchSampler(dataset=dataset, start_mark=args.special_token[0],
use_token_batch=args.use_token_batch, end_mark=args.special_token[1],
batch_size=args.batch_size, unk_mark=args.special_token[2])
pool_size=args.pool_size, args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
sort_type=args.sort_type, args.unk_idx = dataset.get_vocab_summary()
shuffle=args.shuffle, batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
shuffle_batch=args.shuffle_batch, use_token_batch=args.use_token_batch,
max_length=args.max_length) batch_size=args.batch_size,
train_loader = DataLoader(dataset=dataset, pool_size=args.pool_size,
batch_sampler=batch_sampler, sort_type=args.sort_type,
places=device, shuffle=args.shuffle,
feed_list=[x.forward() for x in inputs + labels], shuffle_batch=args.shuffle_batch,
collate_fn=partial(prepare_train_input, max_length=args.max_length)
src_pad_idx=args.eos_idx, data_loader = DataLoader(dataset=dataset,
trg_pad_idx=args.eos_idx, batch_sampler=batch_sampler,
n_head=args.n_head), places=device,
num_workers=0, feed_list=[x.forward() for x in inputs + labels],
return_list=True) collate_fn=partial(prepare_train_input,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0,
return_list=True)
data_loaders[i] = data_loader
train_loader, eval_loader = data_loaders
# define model
transformer = Transformer( transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1, args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model, args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
...@@ -131,17 +135,15 @@ def do_train(args): ...@@ -131,17 +135,15 @@ def do_train(args):
args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd, args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
args.weight_sharing, args.bos_idx, args.eos_idx) args.weight_sharing, args.bos_idx, args.eos_idx)
transformer.prepare( transformer.prepare(fluid.optimizer.Adam(
fluid.optimizer.Adam( learning_rate=fluid.layers.noam_decay(args.d_model, args.warmup_steps),
learning_rate=fluid.layers.noam_decay( beta1=args.beta1,
args.d_model, args.warmup_steps), # args.learning_rate), beta2=args.beta2,
beta1=args.beta1, epsilon=float(args.eps),
beta2=args.beta2, parameter_list=transformer.parameters()),
epsilon=float(args.eps), CrossEntropyCriterion(args.label_smooth_eps),
parameter_list=transformer.parameters()), inputs=inputs,
CrossEntropyCriterion(args.label_smooth_eps), labels=labels)
inputs=inputs,
labels=labels)
## init from some checkpoint, to resume the previous training ## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint: if args.init_from_checkpoint:
...@@ -159,8 +161,9 @@ def do_train(args): ...@@ -159,8 +161,9 @@ def do_train(args):
(1. - args.label_smooth_eps)) + args.label_smooth_eps * (1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20)) np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
# model train
transformer.fit(train_data=train_loader, transformer.fit(train_data=train_loader,
eval_data=None, eval_data=eval_loader,
epochs=1, epochs=1,
eval_freq=1, eval_freq=1,
save_freq=1, save_freq=1,
......
...@@ -652,8 +652,9 @@ class InferTransformer(Transformer): ...@@ -652,8 +652,9 @@ class InferTransformer(Transformer):
eos_id=1, eos_id=1,
beam_size=4, beam_size=4,
max_out_len=256): max_out_len=256):
args = locals() args = dict(locals())
args.pop("self") args.pop("self")
args.pop("__class__", None) # py3
self.beam_size = args.pop("beam_size") self.beam_size = args.pop("beam_size")
self.max_out_len = args.pop("max_out_len") self.max_out_len = args.pop("max_out_len")
super(InferTransformer, self).__init__(**args) super(InferTransformer, self).__init__(**args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册