Commit fcdd4178 authored by guosheng

Make the Transformer network configurations more flexible

Parent 0b48d785
......@@ -79,8 +79,14 @@ class ModelHyperParams(object):
n_head = 8
# number of identical layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rate used by all dropout layers.
dropout = 0.1
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# operations to apply before each sub-layer
preprocess_cmd = "n" # layer normalization
# operations to apply after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# the source and target vocabularies must be identical for weight sharing.
weight_sharing = True
......
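The two command strings above compose the pre/post sub-layer processing pipeline character by character. A minimal sketch of the assumed semantics, mirroring how the pre_process_layer/post_process_layer helpers consume these strings later in this diff (the helper below is illustrative, not the model's API):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize over the last dimension.
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def process_cmd(x, cmd, residual=None, dropout_rate=0.1, training=True):
    # Interpret each character of the command string, left to right:
    #   "a": add the residual input, "n": layer normalization, "d": dropout.
    for c in cmd:
        if c == "a" and residual is not None:
            x = x + residual
        elif c == "n":
            x = layer_norm(x)
        elif c == "d" and dropout_rate and training:
            keep = 1.0 - dropout_rate
            x = x * np.random.binomial(1, keep, x.shape) / keep
    return x

x = np.random.randn(2, 4, 8)
normed = process_cmd(x, "n")                          # preprocess_cmd = "n"
out = process_cmd(np.tanh(normed), "da", residual=x)  # postprocess_cmd = "da"
```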
......@@ -335,7 +335,10 @@ def py_infer(test_data, trg_idx2word):
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing)
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
......@@ -344,7 +347,10 @@ def py_infer(test_data, trg_idx2word):
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing)
# Load model parameters of encoder and decoder separately from the saved
# transformer model.
......@@ -477,7 +483,9 @@ def fast_infer(test_data, trg_idx2word):
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.d_inner_hid, ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)
......
......@@ -37,6 +37,9 @@ def multi_head_attention(queries,
computing softmax activation to mask certain selected positions so that
they will not be considered in the attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
......@@ -95,11 +98,11 @@ def multi_head_attention(queries,
x=trans_x,
shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_model**-0.5)
scaled_q = layers.scale(x=q, scale=d_key**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
weights = layers.reshape(
x=layers.elementwise_add(
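The rename from d_model to d_key corrects the attention scaling: in multi-head attention the queries and keys live in the per-head d_key space, so 1/sqrt(d_key) is the right temperature. A numpy sketch of the computation (shapes are illustrative):

```python
import numpy as np

batch, n_head, q_len, k_len, d_key = 2, 8, 4, 6, 64
q = np.random.randn(batch, n_head, q_len, d_key)
k = np.random.randn(batch, n_head, k_len, d_key)
v = np.random.randn(batch, n_head, k_len, d_key)

scaled_q = q * d_key ** -0.5                      # scale by the per-head dim
product = scaled_q @ k.transpose(0, 1, 3, 2)      # [batch, n_head, q_len, k_len]
weights = np.exp(product - product.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)    # softmax over the keys
out = weights @ v                                 # [batch, n_head, q_len, d_key]
```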
......@@ -138,7 +141,7 @@ def multi_head_attention(queries,
return proj_out
def positionwise_feed_forward(x, d_inner_hid, d_hid):
def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
......@@ -148,6 +151,9 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
if dropout_rate:
hidden = layers.dropout(
hidden, dropout_prob=dropout_rate, is_test=False)
out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
return out
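The new relu_dropout argument applies dropout to the ReLU activations between the two linear transformations. A self-contained sketch of the resulting computation (the weight arrays stand in for the fc-layer parameters):

```python
import numpy as np

def ffn_sketch(x, w1, b1, w2, b2, relu_dropout=0.1, training=True):
    hidden = np.maximum(x @ w1 + b1, 0.0)        # first linear + ReLU
    if relu_dropout and training:
        keep = 1.0 - relu_dropout
        hidden = hidden * np.random.binomial(1, keep, hidden.shape) / keep
    return hidden @ w2 + b2                      # second linear projection

d_model, d_inner_hid = 512, 2048
x = np.random.randn(2, 4, d_model)
out = ffn_sketch(x, np.random.randn(d_model, d_inner_hid), np.zeros(d_inner_hid),
                 np.random.randn(d_inner_hid, d_model), np.zeros(d_model))
```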
......@@ -228,7 +234,11 @@ def encoder_layer(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
pre_softmax_shape=None,
post_softmax_shape=None):
"""The encoder layers that can be stacked to form a deep encoder.
......@@ -238,12 +248,16 @@ def encoder_layer(enc_input,
and dropout.
"""
attn_output = multi_head_attention(
enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
attn_output = post_process_layer(enc_input, attn_output, "dan",
dropout_rate)
ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)
pre_process_layer(enc_input, preprocess_cmd, prepostprocess_dropout),
None, None, attn_bias, d_key, d_value, d_model, n_head,
attention_dropout, pre_softmax_shape, post_softmax_shape)
attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
prepostprocess_dropout)
ffd_output = positionwise_feed_forward(
pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout),
d_inner_hid, d_model, relu_dropout)
return post_process_layer(attn_output, ffd_output, postprocess_cmd,
prepostprocess_dropout)
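Taken together, this rewrite generalizes the old hard-wired post-norm residual block (post_process_layer with "dan") into configurable pre/post commands; with the new defaults (preprocess_cmd="n", postprocess_cmd="da") each sub-layer becomes a pre-norm block, which is why the encoder and decoder stacks below now finish with one extra pre_process_layer to normalize the final output. A schematic comparison, with illustrative helpers rather than the model's API:

```python
def post_norm_block(x, sublayer, dropout, layer_norm):
    # Old fixed wiring: post_process_layer(..., "dan", ...)
    # = dropout the sub-layer output, add the residual, then normalize.
    return layer_norm(x + dropout(sublayer(x)))

def pre_norm_block(x, sublayer, dropout, layer_norm):
    # New default wiring: normalize first ("n"), run the sub-layer,
    # then dropout and add the residual ("da").
    return x + dropout(sublayer(layer_norm(x)))
```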
def encoder(enc_input,
......@@ -254,7 +268,11 @@ def encoder(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
pre_softmax_shape=None,
post_softmax_shape=None):
"""
......@@ -270,10 +288,16 @@ def encoder(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
pre_softmax_shape,
post_softmax_shape, )
enc_input = enc_output
enc_output = pre_process_layer(enc_output, preprocess_cmd,
prepostprocess_dropout)
return enc_output
......@@ -286,7 +310,11 @@ def decoder_layer(dec_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
......@@ -297,25 +325,26 @@ def decoder_layer(dec_input,
a multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output = multi_head_attention(
dec_input,
dec_input,
dec_input,
pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout),
None,
None,
slf_attn_bias,
d_key,
d_value,
d_model,
n_head,
dropout_rate,
attention_dropout,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
cache, )
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
postprocess_cmd,
prepostprocess_dropout, )
enc_attn_output = multi_head_attention(
slf_attn_output,
pre_process_layer(slf_attn_output, preprocess_cmd,
prepostprocess_dropout),
enc_output,
enc_output,
dec_enc_attn_bias,
......@@ -323,23 +352,25 @@ def decoder_layer(dec_input,
d_value,
d_model,
n_head,
dropout_rate,
attention_dropout,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
postprocess_cmd,
prepostprocess_dropout, )
ffd_output = positionwise_feed_forward(
enc_attn_output,
pre_process_layer(enc_attn_output, preprocess_cmd,
prepostprocess_dropout),
d_inner_hid,
d_model, )
d_model,
relu_dropout, )
dec_output = post_process_layer(
enc_attn_output,
ffd_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
postprocess_cmd,
prepostprocess_dropout, )
return dec_output
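Under the same defaults, the three decoder sub-layers chain as below (schematic only; the helper names are illustrative):

```python
def decoder_block(x, enc_out, self_attn, cross_attn, ffn, dropout, layer_norm):
    x = x + dropout(self_attn(layer_norm(x)))            # masked self-attention
    x = x + dropout(cross_attn(layer_norm(x), enc_out))  # encoder-decoder attention
    x = x + dropout(ffn(layer_norm(x)))                  # position-wise FFN
    return x
```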
......@@ -353,7 +384,11 @@ def decoder(dec_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
......@@ -373,13 +408,19 @@ def decoder(dec_input,
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape,
None if caches is None else caches[i], )
dec_input = dec_output
dec_output = pre_process_layer(dec_output, preprocess_cmd,
prepostprocess_dropout)
return dec_output
......@@ -410,7 +451,11 @@ def transformer(
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
label_smooth_eps, ):
if weight_sharing:
......@@ -429,7 +474,11 @@ def transformer(
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs, )
......@@ -445,7 +494,11 @@ def transformer(
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs,
enc_output, )
......@@ -477,7 +530,11 @@ def wrap_encoder(src_vocab_size,
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs=None):
"""
......@@ -499,7 +556,7 @@ def wrap_encoder(src_vocab_size,
src_vocab_size,
d_model,
max_length,
dropout_rate,
prepostprocess_dropout,
src_data_shape,
word_emb_param_name=word_emb_param_names[0])
enc_output = encoder(
......@@ -511,7 +568,11 @@ def wrap_encoder(src_vocab_size,
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
return enc_output
......@@ -525,7 +586,11 @@ def wrap_decoder(trg_vocab_size,
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs=None,
enc_output=None,
......@@ -552,7 +617,7 @@ def wrap_decoder(trg_vocab_size,
trg_vocab_size,
d_model,
max_length,
dropout_rate,
prepostprocess_dropout,
trg_data_shape,
word_emb_param_name=word_emb_param_names[0]
if weight_sharing else word_emb_param_names[1])
......@@ -567,7 +632,11 @@ def wrap_decoder(trg_vocab_size,
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
......@@ -603,7 +672,11 @@ def fast_decode(
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beam_size,
max_out_len,
......@@ -612,9 +685,10 @@ def fast_decode(
Use beam search to decode. Caches are used to store the states of previous
steps, which makes decoding faster.
"""
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
dropout_rate, weight_sharing)
enc_output = wrap_encoder(
src_vocab_size, max_in_len, n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd, weight_sharing)
start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
......@@ -679,7 +753,11 @@ def fast_decode(
d_value,
d_model,
d_inner_hid,
dropout_rate,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs=(
pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
......
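The caches mentioned in the fast_decode docstring avoid recomputing attention keys and values for the already-decoded prefix at every step. The actual cache layout is not visible in this diff; a generic sketch of the idea, with hypothetical names:

```python
import numpy as np

d_model = 512
# One cache per decoder layer: keys/values of all previously decoded steps.
cache = {"k": np.zeros((1, 0, d_model)), "v": np.zeros((1, 0, d_model))}

def step(cache, new_k, new_v):
    # Append the newest token's key/value, then attend over the whole prefix.
    cache["k"] = np.concatenate([cache["k"], new_k], axis=1)
    cache["v"] = np.concatenate([cache["v"], new_v], axis=1)
    return cache["k"], cache["v"]

k, v = step(cache, np.random.randn(1, 1, d_model), np.random.randn(1, 1, d_model))
```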
......@@ -454,7 +454,9 @@ def train(args):
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.d_inner_hid, ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps,
......