Commit 6ef54e8e authored by guosheng

Refine Transformer by following comments and fix the target self attention bias in inference.

Parent ff80721e
@@ -3,34 +3,36 @@ class TrainTaskConfig(object):
# the epoch number to train.
pass_num = 2
# number of sequences contained in a mini-batch.
# the number of sequences contained in a mini-batch.
batch_size = 64
# the hyper params for Adam optimizer.
# the hyper parameters for Adam optimizer.
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.98
eps = 1e-9
# the params for learning rate scheduling
# the parameters for learning rate scheduling.
warmup_steps = 4000
# the directory for saving inference models
model_dir = "transformer_model"
# the directory for saving trained models.
model_dir = "trained_models"
class InferTaskConfig(object):
use_gpu = False
# number of sequences contained in a mini-batch
# the number of examples in one run for sequence generation.
# currently the batch size can only be set to 1.
batch_size = 1
# the params for beam search
# the parameters for beam search.
beam_size = 5
max_length = 30
# the number of decoded sentences to output.
n_best = 1
# the directory for loading inference model
model_path = "transformer_model/pass_1.infer.model"
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
class ModelHyperParams(object):
......
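The Adam settings and `warmup_steps` in the config above follow the usual Transformer training recipe. The sketch below shows the learning-rate schedule that `warmup_steps = 4000` is typically paired with (linear warmup, then inverse-square-root decay, as in "Attention Is All You Need"). It is an illustration only, not code from this commit, and `d_model = 512` is an assumed default since `ModelHyperParams` is truncated here.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

print(transformer_lr(1))        # tiny rate at the start of warmup
print(transformer_lr(4000))     # peak exactly at step == warmup_steps
print(transformer_lr(16000))    # decays as 1/sqrt(step) afterwards
```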
@@ -66,12 +66,19 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
-1], enc_in_data[-2], 1
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_max_len, trg_max_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :],
[beam_size, 1, trg_max_len, 1])
enc_output = np.tile(enc_output, [beam_size, 1, 1])
# No need for trg_slf_attn_bias because of no paddings.
return trg_words, trg_pos, None, trg_src_attn_bias, enc_output
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
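The triangular `trg_slf_attn_bias` built above is what keeps each decoding position from attending to later positions of the target prefix. A minimal standalone sketch of the same construction, with simplified names and not part of the commit:

```python
import numpy as np

def subsequent_word_bias(batch, n_head, trg_len):
    # 1s strictly above the diagonal mark the "future" positions ...
    bias = np.triu(np.ones((batch, trg_len, trg_len)), 1)
    # ... which become a large negative bias per attention head, so they
    # vanish after the softmax inside scaled dot-product attention.
    bias = bias.reshape([-1, 1, trg_len, trg_len])
    return (np.tile(bias, [1, n_head, 1, 1]) * -1e9).astype("float32")

# For one instance, two heads and three target positions, row t of the
# mask allows attention only to positions 0..t (0 keeps, -1e9 masks).
print(subsequent_word_bias(1, 2, 3)[0, 0])
```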
def update_dec_in_data(dec_in_data, next_ids, active_beams):
"""
@@ -79,6 +86,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
input data and dropping the finished instance beams.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
trg_cur_len = len(next_ids[0]) + 1 # include the <bos>
trg_words = np.array(
[
beam_backtrace(
@@ -88,14 +96,22 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
dtype="int64")
trg_words = trg_words.reshape([-1, 1])
trg_pos = np.array(
[range(1, len(next_ids[0]) + 2)] * len(active_beams) * beam_size,
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len,
trg_cur_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_cur_len, trg_cur_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, len(next_ids[0]) + 1, 1])
[1, 1, trg_cur_len, 1])
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
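The `active_beams_indice` computed above maps instance indices to row indices of the flattened `(instance * beam)` tensors: every instance owns `beam_size` consecutive rows, so finished instances can be dropped by selecting only the rows of the still-active ones. A toy example with made-up values:

```python
import numpy as np

beam_size = 3
active_beams = [0, 2]   # instance 1 has finished; instances 0 and 2 continue
active_beams_indice = (
    (np.array(active_beams) * beam_size)[:, np.newaxis] +
    np.array(range(beam_size))[np.newaxis, :]).flatten()
print(active_beams_indice)   # [0 1 2 6 7 8] -> rows kept for the next step
```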
@@ -103,9 +119,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
enc_output)
for i in range(max_length):
predict_all = exe.run(decoder,
feed=dict(
filter(lambda item: item[1] is not None,
zip(dec_in_names, dec_in_data))),
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(predict_all)
predict_all = (
@@ -206,9 +220,9 @@ def main():
encoder_input_data_names, [enc_output.name], decoder_program,
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.n_best,
InferTaskConfig.batch_size, ModelHyperParams.n_head,
ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx)
for i in range(len(batch_seqs)):
seqs = batch_seqs[i]
scores = batch_scores[i]
......
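In the decoding loop above, the `filter(...)` wrapper around the feed dict could be dropped because the decoder inputs no longer contain `None`: the self-attention bias is now built explicitly every step. A self-contained sketch of the resulting feed, using placeholder names and shapes rather than the repository's actual `decoder_input_data_names`:

```python
import numpy as np

dec_in_names = ["trg_word", "trg_pos", "trg_slf_attn_bias",
                "trg_src_attn_bias", "enc_output"]      # placeholder names
dec_in_data = [
    np.zeros((5, 1), "int64"),            # current target words
    np.ones((5, 1), "int64"),             # target positions
    np.zeros((5, 8, 1, 1), "float32"),    # triangular self-attention bias
    np.zeros((5, 8, 1, 4), "float32"),    # source-attention padding bias
    np.zeros((5, 4, 512), "float32"),     # encoder output
]
feed = dict(zip(dec_in_names, dec_in_data))   # no None entries to filter out
print(sorted(feed.keys()))
```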
@@ -283,8 +283,15 @@ def encoder(enc_input,
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
d_model, d_inner_hid, dropout_rate)
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
enc_input = enc_output
return enc_output
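The loop above stacks `n_layer` identical encoder layers by feeding each layer's output back in as the next layer's input and returning only the last output. A schematic of that pattern with a generic layer function, not the Fluid code itself:

```python
def stack_layers(x, layer_fn, n_layer):
    for _ in range(n_layer):
        x = layer_fn(x)   # mirrors `enc_input = enc_output` above
    return x

print(stack_layers(1, lambda v: v + 1, 6))   # 7: six layers applied in sequence
```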
@@ -381,9 +388,10 @@ def make_inputs(input_data_names,
d_model,
batch_size,
max_length,
is_pos,
slf_attn_bias_flag,
src_attn_bias_flag,
pos_flag=1):
enc_output_flag=False):
"""
Define the input data layers for the transformer model.
"""
@@ -391,35 +399,43 @@ def make_inputs(input_data_names,
# The shapes here act as placeholders.
# The shapes set here are only used to pass the infer-shape check at compile time.
word = layers.data(
name=input_data_names[0],
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
input_layers += [word]
# This is used for position data or label weight.
pos = layers.data(
name=input_data_names[1],
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
dtype="int64" if pos_flag else "float32",
dtype="int64" if is_pos else "float32",
append_batch_size=False)
input_layers += [pos]
if slf_attn_bias_flag:
# This is used for attention bias or encoder output.
# This input is used to remove attention weights on paddings for the
# encoder and to remove attention weights on subsequent words for the
# decoder.
slf_attn_bias = layers.data(
name=input_data_names[2]
if slf_attn_bias_flag == 1 else input_data_names[-1],
shape=[batch_size, n_head, max_length, max_length]
if slf_attn_bias_flag == 1 else [batch_size, max_length, d_model],
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
input_layers += [slf_attn_bias]
if src_attn_bias_flag:
# This input is used to remove attention weights on paddings.
src_attn_bias = layers.data(
name=input_data_names[3],
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
input_layers += [src_attn_bias]
if enc_output_flag:
enc_output = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, max_length, d_model],
dtype="float32",
append_batch_size=False)
input_layers += [enc_output]
return input_layers
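With this refactor, `make_inputs` no longer hard-codes which name belongs to which layer: every new `layers.data` takes `input_data_names[len(input_layers)]`, the next unused name, so the flags only decide which layers exist. A rough sketch of that indexing, simplified so that every layer sits behind a flag (the real function always creates the word and position inputs), with placeholder names rather than the repository's actual lists:

```python
def pick_names(input_data_names, flags):
    # mimic make_inputs: every created layer consumes the next unused name
    input_layers = []
    for wanted in flags:
        if wanted:
            input_layers.append(input_data_names[len(input_layers)])
    return input_layers

names = ["trg_word", "trg_pos", "trg_slf_attn_bias",
         "trg_src_attn_bias", "enc_output"]
print(pick_names(names, [True, True, True, True, True]))   # decoder in inference
print(pick_names(names, [True, True, True, True, False]))  # decoder in training
```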
@@ -438,7 +454,7 @@ def transformer(
trg_pad_idx,
pos_pad_idx, ):
enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
batch_size, max_length, 1, 0)
batch_size, max_length, True, True, False)
enc_output = wrap_encoder(
src_vocab_size,
@@ -455,7 +471,7 @@ def transformer(
enc_input_layers, )
dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
batch_size, max_length, 1, 1)
batch_size, max_length, True, True, True)
predict = wrap_decoder(
trg_vocab_size,
@@ -475,7 +491,7 @@ def transformer(
# Padding indices do not contribute to the total loss. The weights are used to
# cancel padding indices in calculating the loss.
gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
max_length, 0, 0, 0)
max_length, False, False, False)
cost = layers.cross_entropy(input=predict, label=gold)
weighted_cost = cost * weights
return layers.reduce_sum(weighted_cost), predict
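The `weights` input returned by `make_inputs` above carries 1.0 at real target tokens and 0.0 at padding, so multiplying it into the per-token cross-entropy removes padded positions from the summed loss. A numpy illustration with made-up numbers:

```python
import numpy as np

cost = np.array([0.7, 1.2, 0.3, 0.9])      # per-token cross-entropy
weights = np.array([1.0, 1.0, 1.0, 0.0])   # last position is padding
print((cost * weights).sum())              # 2.2, padding contributes nothing
```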
@@ -500,7 +516,7 @@ def wrap_encoder(src_vocab_size,
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_inputs(
encoder_input_data_names, n_head, d_model, batch_size, max_length,
True, False)
True, True, False)
else:
src_word, src_pos, src_slf_attn_bias = enc_input_layers
enc_input = prepare_encoder(
@@ -542,11 +558,9 @@ def wrap_decoder(trg_vocab_size,
"""
if dec_input_layers is None:
# This is used to implement independent decoder program in inference.
# No need for trg_slf_attn_bias because of no paddings in inference.
trg_word, trg_pos, enc_output, trg_src_attn_bias = make_inputs(
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
decoder_input_data_names, n_head, d_model, batch_size, max_length,
2, 1)
trg_slf_attn_bias = None
True, True, True, True)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
......
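This last hunk is the core of the fix: the standalone inference decoder previously received `trg_slf_attn_bias = None` on the grounds that generated prefixes contain no padding, but without the bias nothing stopped the decoder's self-attention from looking at later positions of the partially decoded prefix. The toy example below, an illustration only with uniform scores, shows the effect of adding the triangular bias before the softmax:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.zeros((3, 3))                  # uniform attention logits
bias = np.triu(np.ones((3, 3)), 1) * -1e9  # mask subsequent positions
print(softmax(scores)[0])                  # ~[0.33 0.33 0.33]: leaks ahead
print(softmax(scores + bias)[0])           # [1. 0. 0.]: position 0 sees only itself
```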