Commit 6ef54e8e authored by guosheng

Refine Transformer by following review comments and fix the target self-attention bias in inference.

Parent ff80721e
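For reference, the target self-attention bias that this commit feeds to the decoder at inference time can be sketched in NumPy as follows. This is a minimal sketch mirroring the mask-building code added in the patch below; the function name `build_subsequent_word_bias` is illustrative and not part of the repository.

```python
import numpy as np


def build_subsequent_word_bias(n_inst, trg_len, n_head):
    """Illustrative sketch: additive bias that blocks attention to
    subsequent (future) target words, as the patched inference code builds."""
    # Upper-triangular ones mark the positions each query must NOT attend to.
    mask = np.triu(np.ones((n_inst, trg_len, trg_len)), k=1)
    # Add a head axis and repeat the mask for every attention head.
    bias = np.tile(mask.reshape([-1, 1, trg_len, trg_len]), [1, n_head, 1, 1])
    # A large negative value drives the masked logits to ~0 after softmax.
    return (bias * -1e9).astype("float32")


# Example: 2 beam instances, 4 decoded positions so far, 8 attention heads.
print(build_subsequent_word_bias(2, 4, 8).shape)  # (2, 8, 4, 4)
```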
@@ -3,34 +3,36 @@ class TrainTaskConfig(object):
     # the epoch number to train.
     pass_num = 2
-    # number of sequences contained in a mini-batch.
+    # the number of sequences contained in a mini-batch.
     batch_size = 64
-    # the hyper params for Adam optimizer.
+    # the hyper parameters for Adam optimizer.
     learning_rate = 0.001
     beta1 = 0.9
     beta2 = 0.98
     eps = 1e-9
-    # the params for learning rate scheduling
+    # the parameters for learning rate scheduling.
     warmup_steps = 4000
-    # the directory for saving inference models
-    model_dir = "transformer_model"
+    # the directory for saving trained models.
+    model_dir = "trained_models"
 class InferTaskConfig(object):
     use_gpu = False
-    # number of sequences contained in a mini-batch
+    # the number of examples in one run for sequence generation.
+    # currently the batch size can only be set to 1.
     batch_size = 1
-    # the params for beam search
+    # the parameters for beam search.
     beam_size = 5
     max_length = 30
+    # the number of decoded sentences to output.
     n_best = 1
-    # the directory for loading inference model
-    model_path = "transformer_model/pass_1.infer.model"
+    # the directory for loading the trained model.
+    model_path = "trained_models/pass_1.infer.model"
 class ModelHyperParams(object):
......
@@ -66,12 +66,19 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
         src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
             -1], enc_in_data[-2], 1
+        # This is used to remove attention on subsequent words.
+        trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
+                                     trg_max_len))
+        trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
+            [-1, 1, trg_max_len, trg_max_len])
+        trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
+                             [-1e9]).astype("float32")
+        # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(
             src_slf_attn_bias[:, :, ::src_max_length, :],
             [beam_size, 1, trg_max_len, 1])
         enc_output = np.tile(enc_output, [beam_size, 1, 1])
-        # No need for trg_slf_attn_bias because of no paddings.
-        return trg_words, trg_pos, None, trg_src_attn_bias, enc_output
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
     def update_dec_in_data(dec_in_data, next_ids, active_beams):
         """
@@ -79,6 +86,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         input data and dropping the finished instance beams.
         """
         trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
+        trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
         trg_words = np.array(
             [
                 beam_backtrace(
@@ -88,14 +96,22 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
             dtype="int64")
         trg_words = trg_words.reshape([-1, 1])
         trg_pos = np.array(
-            [range(1, len(next_ids[0]) + 2)] * len(active_beams) * beam_size,
+            [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
             dtype="int64").reshape([-1, 1])
         active_beams_indice = (
             (np.array(active_beams) * beam_size)[:, np.newaxis] +
             np.array(range(beam_size))[np.newaxis, :]).flatten()
+        # This is used to remove attention on subsequent words.
+        trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len,
+                                     trg_cur_len))
+        trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
+            [-1, 1, trg_cur_len, trg_cur_len])
+        trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
+                             [-1e9]).astype("float32")
+        # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(trg_src_attn_bias[
             active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
-            [1, 1, len(next_ids[0]) + 1, 1])
+            [1, 1, trg_cur_len, 1])
         enc_output = enc_output[active_beams_indice, :, :]
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
@@ -103,9 +119,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
                                    enc_output)
     for i in range(max_length):
         predict_all = exe.run(decoder,
-                              feed=dict(
-                                  filter(lambda item: item[1] is not None,
-                                         zip(dec_in_names, dec_in_data))),
+                              feed=dict(zip(dec_in_names, dec_in_data)),
                               fetch_list=dec_out_names)[0]
         predict_all = np.log(predict_all)
         predict_all = (
@@ -206,9 +220,9 @@ def main():
         encoder_input_data_names, [enc_output.name], decoder_program,
         decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
         InferTaskConfig.max_length, InferTaskConfig.n_best,
-        InferTaskConfig.batch_size, ModelHyperParams.n_head,
-        ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
-        ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
+        len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
+        ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
+        ModelHyperParams.eos_idx)
     for i in range(len(batch_seqs)):
         seqs = batch_seqs[i]
         scores = batch_scores[i]
......
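The additive biases built in the hunks above (both the subsequent-word bias and the source-padding bias) are consumed by adding them to the scaled dot-product scores before the softmax, so masked positions get effectively zero attention weight. A minimal NumPy sketch of that mechanism, independent of the repository code:

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


# Toy shapes: 1 instance, 1 head, 3 target positions, d_k = 4.
rng = np.random.default_rng(0)
q = rng.normal(size=(1, 1, 3, 4))
k = rng.normal(size=(1, 1, 3, 4))

# Additive bias: 0 where attention is allowed, -1e9 on subsequent words.
bias = np.triu(np.ones((3, 3)), k=1)[None, None] * -1e9

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(4) + bias
weights = softmax(scores)
# Each row now puts (almost) zero probability mass on future positions.
print(np.round(weights[0, 0], 3))
```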
@@ -283,8 +283,15 @@ def encoder(enc_input,
     encoder_layer.
     """
     for i in range(n_layer):
-        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
-                                   d_model, d_inner_hid, dropout_rate)
+        enc_output = encoder_layer(
+            enc_input,
+            attn_bias,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            dropout_rate, )
         enc_input = enc_output
     return enc_output
@@ -381,9 +388,10 @@ def make_inputs(input_data_names,
                 d_model,
                 batch_size,
                 max_length,
+                is_pos,
                 slf_attn_bias_flag,
                 src_attn_bias_flag,
-                pos_flag=1):
+                enc_output_flag=False):
     """
     Define the input data layers for the transformer model.
     """
@@ -391,35 +399,43 @@ def make_inputs(input_data_names,
     # The shapes here act as placeholder.
     # The shapes set here is to pass the infer-shape in compile time.
     word = layers.data(
-        name=input_data_names[0],
+        name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
         dtype="int64",
         append_batch_size=False)
     input_layers += [word]
     # This is used for position data or label weight.
     pos = layers.data(
-        name=input_data_names[1],
+        name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
-        dtype="int64" if pos_flag else "float32",
+        dtype="int64" if is_pos else "float32",
         append_batch_size=False)
     input_layers += [pos]
     if slf_attn_bias_flag:
-        # This is used for attention bias or encoder output.
+        # This input is used to remove attention weights on paddings for the
+        # encoder and to remove attention weights on subsequent words for the
+        # decoder.
         slf_attn_bias = layers.data(
-            name=input_data_names[2]
-            if slf_attn_bias_flag == 1 else input_data_names[-1],
-            shape=[batch_size, n_head, max_length, max_length]
-            if slf_attn_bias_flag == 1 else [batch_size, max_length, d_model],
+            name=input_data_names[len(input_layers)],
+            shape=[batch_size, n_head, max_length, max_length],
            dtype="float32",
             append_batch_size=False)
         input_layers += [slf_attn_bias]
     if src_attn_bias_flag:
+        # This input is used to remove attention weights on paddings.
         src_attn_bias = layers.data(
-            name=input_data_names[3],
+            name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
             dtype="float32",
             append_batch_size=False)
         input_layers += [src_attn_bias]
+    if enc_output_flag:
+        enc_output = layers.data(
+            name=input_data_names[len(input_layers)],
+            shape=[batch_size, max_length, d_model],
+            dtype="float32",
+            append_batch_size=False)
+        input_layers += [enc_output]
     return input_layers
@@ -438,7 +454,7 @@ def transformer(
         trg_pad_idx,
         pos_pad_idx, ):
     enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
-                                   batch_size, max_length, 1, 0)
+                                   batch_size, max_length, True, True, False)
     enc_output = wrap_encoder(
         src_vocab_size,
@@ -455,7 +471,7 @@ def transformer(
         enc_input_layers, )
     dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
-                                   batch_size, max_length, 1, 1)
+                                   batch_size, max_length, True, True, True)
     predict = wrap_decoder(
         trg_vocab_size,
@@ -475,7 +491,7 @@ def transformer(
     # Padding index do not contribute to the total loss. The weights is used to
     # cancel padding index in calculating the loss.
     gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
-                                max_length, 0, 0, 0)
+                                max_length, False, False, False)
     cost = layers.cross_entropy(input=predict, label=gold)
     weighted_cost = cost * weights
     return layers.reduce_sum(weighted_cost), predict
@@ -500,7 +516,7 @@ def wrap_encoder(src_vocab_size,
         # This is used to implement independent encoder program in inference.
         src_word, src_pos, src_slf_attn_bias = make_inputs(
             encoder_input_data_names, n_head, d_model, batch_size, max_length,
-            True, False)
+            True, True, False)
     else:
         src_word, src_pos, src_slf_attn_bias = enc_input_layers
     enc_input = prepare_encoder(
@@ -542,11 +558,9 @@ def wrap_decoder(trg_vocab_size,
     """
     if dec_input_layers is None:
         # This is used to implement independent decoder program in inference.
-        # No need for trg_slf_attn_bias because of no paddings in inference.
-        trg_word, trg_pos, enc_output, trg_src_attn_bias = make_inputs(
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
             decoder_input_data_names, n_head, d_model, batch_size, max_length,
-            2, 1)
-        trg_slf_attn_bias = None
+            True, True, True, True)
     else:
         trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
......
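To summarize the refactored `make_inputs` interface above (boolean flags `is_pos`, `slf_attn_bias_flag`, `src_attn_bias_flag`, `enc_output_flag`, with input names picked positionally via `input_data_names[len(input_layers)]`), here is a small plain-Python mock that only reports which inputs a given call site would declare. It is an illustration under assumed input names, not the Fluid code from the diff.

```python
def declared_inputs(input_data_names, is_pos, slf_attn_bias_flag,
                    src_attn_bias_flag, enc_output_flag=False):
    """Mock of make_inputs: return the input names it would declare, in order.
    `is_pos` only affects the dtype of the second input in the real code."""
    names = [input_data_names[0], input_data_names[1]]  # word ids, positions/weights
    if slf_attn_bias_flag:
        names.append(input_data_names[len(names)])  # self-attention bias
    if src_attn_bias_flag:
        names.append(input_data_names[len(names)])  # encoder-decoder attention bias
    if enc_output_flag:
        names.append(input_data_names[len(names)])  # precomputed encoder output
    return names


# Assumed (illustrative) decoder input names, mirroring the variables in the diff.
dec_names = ["trg_word", "trg_pos", "trg_slf_attn_bias", "trg_src_attn_bias",
             "enc_output"]
print(declared_inputs(dec_names, True, True, True))        # training decoder: first four
print(declared_inputs(dec_names, True, True, True, True))  # inference decoder: all five
```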