Commit 57365421 authored by: G guosheng

Update Transformer

Parent 0b93f490
@@ -17,20 +17,20 @@ import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import contextlib
from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.io import DataLoader
from paddle.fluid.layers.utils import flatten
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
# include task-specific libs
import reader
from model import Input, set_device
from reader import prepare_infer_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import InferTransformer, position_encoding_init
from model import Input
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
@@ -51,98 +51,86 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
def do_predict(args):
@contextlib.contextmanager
def null_guard():
yield
guard = fluid.dygraph.guard() if args.eager_run else null_guard()
# define the data generator
processor = reader.DataProcessor(
fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=False,
batch_size=args.batch_size,
device_count=1,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="predict")
device = set_device("gpu" if args.use_cuda else "cpu")
fluid.enable_dygraph(device) if args.eager_run else None
inputs = [
Input([None, None], "int64", name="src_word"),
Input([None, None], "int64", name="src_pos"),
Input([None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input([None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
# define data
dataset = Seq2SeqDataset(fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
trg_idx2word = reader.DataProcessor.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
with guard:
# define data loader
test_loader = batch_generator
# define model
inputs = [
Input(
[None, None], "int64", name="src_word"),
Input(
[None, None], "int64", name="src_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
transformer = InferTransformer(
args.src_vocab_size,
args.trg_vocab_size,
args.max_length + 1,
args.n_layer,
args.n_head,
args.d_key,
args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.bos_idx,
args.eos_idx,
beam_size=args.beam_size,
max_out_len=args.max_out_len)
transformer.prepare(inputs=inputs)
# load the trained model
assert args.init_from_params, (
"Please set init_from_params to load the infer model.")
transformer.load(os.path.join(args.init_from_params, "transformer"))
f = open(args.output_file, "wb")
for input_data in test_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word,
trg_src_attn_bias) = input_data
finished_seq = transformer.test(inputs=(
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias))[0]
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
if beam_idx >= args.n_best: break
id_list = post_process_seq(beam, args.bos_idx,
args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
sequence = b" ".join(word_list) + b"\n"
f.write(sequence)
break
args.unk_idx = dataset.get_vocab_summary()
trg_idx2word = Seq2SeqDataset.load_dict(dict_path=args.trg_vocab_fpath,
reverse=True)
batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
use_token_batch=False,
batch_size=args.batch_size,
max_length=args.max_length)
data_loader = DataLoader(dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=[x.forward() for x in inputs],
collate_fn=partial(prepare_infer_input,
src_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0,
return_list=True)
# define model
transformer = InferTransformer(args.src_vocab_size,
args.trg_vocab_size,
args.max_length + 1,
args.n_layer,
args.n_head,
args.d_key,
args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.bos_idx,
args.eos_idx,
beam_size=args.beam_size,
max_out_len=args.max_out_len)
transformer.prepare(inputs=inputs)
# load the trained model
assert args.init_from_params, (
"Please set init_from_params to load the infer model.")
transformer.load(os.path.join(args.init_from_params, "transformer"))
# TODO: use model.predict when support variant length
f = open(args.output_file, "wb")
for data in data_loader():
finished_seq = transformer.test(inputs=flatten(data))[0]
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
if beam_idx >= args.n_best: break
id_list = post_process_seq(beam, args.bos_idx,
args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
sequence = b" ".join(word_list) + b"\n"
f.write(sequence)
if __name__ == "__main__":
......
@@ -60,22 +60,19 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
return data_inputs
def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
def prepare_infer_input(insts, src_pad_idx, n_head):
"""
Put all padded data needed by beam search decoder into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
]
return data_inputs
@@ -343,11 +340,11 @@ class Seq2SeqBatchSampler(BatchSampler):
def __init__(self,
dataset,
batch_size,
pool_size,
sort_type=SortType.GLOBAL,
pool_size=10000,
sort_type=SortType.NONE,
min_length=0,
max_length=100,
shuffle=True,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
@@ -412,7 +409,7 @@ class Seq2SeqBatchSampler(BatchSampler):
batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks)
] for batch in batches]
batches = itertools.chain.from_iterable(batches)
batches = list(itertools.chain.from_iterable(batches))
# for multi-device
for batch_id, batch in enumerate(batches):
......
@@ -17,8 +17,6 @@ import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import contextlib
from functools import partial
import numpy as np
@@ -30,11 +28,10 @@ from paddle.fluid.io import DataLoader
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
# include task-specific libs
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
from model import Input, set_device
from callbacks import ProgBarLogger
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
class LoggerCallback(ProgBarLogger):
@@ -72,7 +69,7 @@ def do_train(args):
fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed
# define model
# define inputs
inputs = [
Input([None, None], "int64", name="src_word"),
Input([None, None], "int64", name="src_pos"),
@@ -95,35 +92,42 @@ def do_train(args):
[None, 1], "float32", name="weight"),
]
dataset = Seq2SeqDataset(fpattern=args.training_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length)
train_loader = DataLoader(dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=[x.forward() for x in inputs + labels],
collate_fn=partial(prepare_train_input,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0,
return_list=True)
# def dataloader
data_loaders = [None, None]
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length)
data_loader = DataLoader(dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=[x.forward() for x in inputs + labels],
collate_fn=partial(prepare_train_input,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0,
return_list=True)
data_loaders[i] = data_loader
train_loader, eval_loader = data_loaders
# define model
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
@@ -131,17 +135,15 @@ def do_train(args):
args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
args.weight_sharing, args.bos_idx, args.eos_idx)
transformer.prepare(
fluid.optimizer.Adam(
learning_rate=fluid.layers.noam_decay(
args.d_model, args.warmup_steps), # args.learning_rate),
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps),
inputs=inputs,
labels=labels)
transformer.prepare(fluid.optimizer.Adam(
learning_rate=fluid.layers.noam_decay(args.d_model, args.warmup_steps),
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps),
inputs=inputs,
labels=labels)
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
@@ -159,8 +161,9 @@ def do_train(args):
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
# model train
transformer.fit(train_data=train_loader,
eval_data=None,
eval_data=eval_loader,
epochs=1,
eval_freq=1,
save_freq=1,
......
@@ -652,8 +652,9 @@ class InferTransformer(Transformer):
eos_id=1,
beam_size=4,
max_out_len=256):
args = locals()
args = dict(locals())
args.pop("self")
args.pop("__class__", None) # py3
self.beam_size = args.pop("beam_size")
self.max_out_len = args.pop("max_out_len")
super(InferTransformer, self).__init__(**args)
......