Commit a6ec3a0d authored by: G guosheng

Tune the Transformer model for wmt14

Parent: f93838a4
 class TrainTaskConfig(object):
     use_gpu = True
     # the epoch number to train.
-    pass_num = 30
+    pass_num = 200
     # the number of sequences contained in a mini-batch.
     batch_size = 32
     # the hyper parameters for Adam optimizer.
     # This static learning_rate will be multiplied by the LearningRateScheduler
     # derived learning rate to get the final learning rate.
-    learning_rate = 1
+    learning_rate = 2
     beta1 = 0.9
-    beta2 = 0.98
+    beta2 = 0.997
     eps = 1e-9
     # the parameters for learning rate scheduling.
-    warmup_steps = 4000
+    warmup_steps = 8000
     # the flag indicating to use average loss or sum loss when training.
     use_avg_cost = True
     # the weight used to mix up the ground-truth distribution and the fixed
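
For context, the LearningRateScheduler mentioned above follows the standard Transformer warmup-then-decay rule, and the static learning_rate scales it. A minimal Python sketch of that rule (an illustration, not code from this commit; the helper name transformer_lr is hypothetical):

def transformer_lr(step, d_model=512, warmup_steps=8000, learning_rate=2.0):
    # Linear warmup for warmup_steps, then inverse-square-root decay;
    # the static learning_rate from TrainTaskConfig scales the whole curve.
    step = max(step, 1)
    return learning_rate * d_model ** -0.5 * min(step ** -0.5,
                                                 step * warmup_steps ** -1.5)

# The peak rate is reached at step == warmup_steps.
peak = transformer_lr(8000)
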
@@ -33,12 +33,12 @@ class TrainTaskConfig(object):
 class InferTaskConfig(object):
-    use_gpu = True
+    use_gpu = False
     # the number of examples in one run for sequence generation.
-    batch_size = 10
+    batch_size = 2
     # the parameters for beam search.
     beam_size = 5
-    max_length = 30
+    max_out_len = 30
     # the number of decoded sentences to output.
     n_best = 1
     # the flags indicating whether to output the special tokens.
@@ -55,26 +55,26 @@ class ModelHyperParams(object):
     # included in dict can be used to pad, since the paddings' loss will be
     # masked out and have no effect on parameter gradients.
     # size of source word dictionary.
-    src_vocab_size = 10000
+    src_vocab_size = 50000
     # size of target word dictionary.
-    trg_vocab_size = 10000
+    trg_vocab_size = 50000
     # index for <bos> token
-    bos_idx = 0
+    bos_idx = 1
     # index for <eos> token
-    eos_idx = 1
+    eos_idx = 2
     # index for <unk> token
-    unk_idx = 2
+    unk_idx = 0
     # max length of sequences.
     # The size of the position encoding table should be at least max_length
     # plus 1, since the sinusoid position encoding starts from 1 and 0 can be
     # used as the padding token for position encoding.
-    max_length = 50
+    max_length = 256
     # the dimension for word embeddings, which is also the last dimension of
     # the input and output of multi-head attention, position-wise feed-forward
     # networks, encoder and decoder.
     d_model = 512
     # size of the hidden layer in position-wise feed-forward networks.
-    d_inner_hid = 1024
+    d_inner_hid = 2048
     # the dimension that keys are projected to for dot-product attention.
     d_key = 64
     # the dimension that values are projected to for dot-product attention.
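
The max_length comment above refers to the sinusoid table built by position_encoding_init: real positions are numbered from 1 so that index 0 stays free for padding, hence the table needs max_length + 1 rows. A small numpy sketch of such a table (illustrative, not the exact implementation in this file):

import numpy as np

def sinusoid_position_encoding(n_position, d_pos_vec):
    # row p, even column 2i   -> sin(p / 10000^(2i / d_pos_vec))
    # row p, odd  column 2i+1 -> cos(p / 10000^(2i / d_pos_vec))
    pos = np.arange(n_position)[:, None].astype("float64")
    dim = np.arange(d_pos_vec)[None, :]
    angles = pos / np.power(10000., 2 * (dim // 2) / float(d_pos_vec))
    table = np.zeros((n_position, d_pos_vec), dtype="float32")
    table[:, 0::2] = np.sin(angles[:, 0::2])
    table[:, 1::2] = np.cos(angles[:, 1::2])
    table[0, :] = 0.  # row 0 is reserved for padding positions
    return table

pos_enc = sinusoid_position_encoding(256 + 1, 512)  # max_length + 1 rows
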
@@ -89,7 +89,7 @@ class ModelHyperParams(object):
 def merge_cfg_from_list(cfg_list, g_cfgs):
     """
     Set the above global configurations using the cfg_list.
     """
     assert len(cfg_list) % 2 == 0
     for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
@@ -103,23 +103,28 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
                 break
+# The placeholder for batch_size at compile time. Must be -1 currently to be
+# consistent with some ops' infer-shape output at compile time, such as the
+# sequence_expand op used in the beam-search decoder.
+batch_size = -1
+# The placeholder for sequence length at compile time.
+seq_len = ModelHyperParams.max_length
 # Listed here are the data shapes and data types of all inputs.
 # The shapes here act as placeholders and are set to pass the infer-shape at
 # compile time.
 input_descs = {
     # The actual data shape of src_word is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "src_word": [(batch_size * seq_len, 1L), "int64", 2],
     # The actual data shape of src_pos is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "src_pos": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to remove attention weights on paddings in the
     # encoder.
     # The actual data shape of src_slf_attn_bias is:
     # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
-    "src_slf_attn_bias":
-    [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
-      (ModelHyperParams.max_length + 1)), "float32"],
+    "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This shape input is used to reshape the output of embedding layer.
     "src_data_shape": [(3L, ), "int32"],
     # This shape input is used to reshape before softmax in self attention.
@@ -128,24 +133,23 @@ input_descs = {
     "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
     # The actual data shape of trg_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "trg_word": [(batch_size * seq_len, 1L), "int64",
+                 2],  # lod_level is only used in fast decoder.
     # The actual data shape of trg_pos is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "trg_pos": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to remove attention weights on paddings and
     # subsequent words in the decoder.
     # The actual data shape of trg_slf_attn_bias is:
     # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
-    "trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
-                           (ModelHyperParams.max_length + 1),
-                           (ModelHyperParams.max_length + 1)), "float32"],
+    "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This input is used to remove attention weights on paddings of the source
     # input in the encoder-decoder attention.
     # The actual data shape of trg_src_attn_bias is:
     # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
-    "trg_src_attn_bias": [(1, ModelHyperParams.n_head,
-                           (ModelHyperParams.max_length + 1),
-                           (ModelHyperParams.max_length + 1)), "float32"],
+    "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This shape input is used to reshape the output of embedding layer.
     "trg_data_shape": [(3L, ), "int32"],
     # This shape input is used to reshape before softmax in self attention.
@@ -161,17 +165,23 @@ input_descs = {
     # This input is used in the independent decoder program for inference.
     # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
-    "enc_output": [(1, (ModelHyperParams.max_length + 1),
-                    ModelHyperParams.d_model), "float32"],
+    "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
     # The actual data shape of label_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "lbl_word": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to mask out the loss of padding tokens.
     # The actual data shape of label_weight is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
+    "lbl_weight": [(batch_size * seq_len, 1L), "float32"],
+    # These inputs are used to change the shape tensor in the beam-search
+    # decoder.
+    "trg_slf_attn_pre_softmax_shape_delta": [(2L, ), "int32"],
+    "trg_slf_attn_post_softmax_shape_delta": [(4L, ), "int32"],
+    "init_score": [(batch_size, 1L), "float32"],
 }
+word_emb_param_names = (
+    "src_word_emb_table",
+    "trg_word_emb_table", )
 # Names of the position encoding tables which will be initialized externally.
 pos_enc_param_names = (
     "src_pos_enc_table",
@@ -200,3 +210,12 @@ decoder_util_input_fields = (
 label_data_input_fields = (
     "lbl_word",
     "lbl_weight", )
+# In the fast decoder, trg_pos (only containing the current time step) is
+# generated by ops and trg_slf_attn_bias is not needed.
+fast_decoder_data_input_fields = (
+    "trg_word",
+    "init_score",
+    "trg_src_attn_bias", )
+fast_decoder_util_input_fields = decoder_util_input_fields + (
+    "trg_slf_attn_pre_softmax_shape_delta",
+    "trg_slf_attn_post_softmax_shape_delta", )
@@ -6,6 +6,8 @@ import paddle.fluid.layers as layers
 from config import *
+WEIGHT_SHARING = True
 def position_encoding_init(n_position, d_pos_vec):
     """
@@ -30,7 +32,8 @@ def multi_head_attention(queries,
                          n_head=1,
                          dropout_rate=0.,
                          pre_softmax_shape=None,
-                         post_softmax_shape=None):
+                         post_softmax_shape=None,
+                         cache=None):
     """
     Multi-Head Attention. Note that attn_bias is added to the logit before
     computing softmax activation to mask certain selected positions so that
@@ -44,30 +47,30 @@ def multi_head_attention(queries,
         """
         Add linear projection to queries, keys, and values.
         """
-        q = layers.fc(input=queries,
-                      size=d_key * n_head,
-                      param_attr=fluid.initializer.Xavier(
-                          uniform=False,
-                          fan_in=d_model * d_key,
-                          fan_out=n_head * d_key),
-                      bias_attr=False,
-                      num_flatten_dims=2)
+        q = layers.fc(
+            input=queries,
+            size=d_key * n_head,
+            param_attr=fluid.initializer.Xavier(uniform=True),
+            # fan_in=d_model * d_key,
+            # fan_out=n_head * d_key),
+            bias_attr=False,
+            num_flatten_dims=2)
-        k = layers.fc(input=keys,
-                      size=d_key * n_head,
-                      param_attr=fluid.initializer.Xavier(
-                          uniform=False,
-                          fan_in=d_model * d_key,
-                          fan_out=n_head * d_key),
-                      bias_attr=False,
-                      num_flatten_dims=2)
+        k = layers.fc(
+            input=keys,
+            size=d_key * n_head,
+            param_attr=fluid.initializer.Xavier(uniform=True),
+            # fan_in=d_model * d_key,
+            # fan_out=n_head * d_key),
+            bias_attr=False,
+            num_flatten_dims=2)
-        v = layers.fc(input=values,
-                      size=d_value * n_head,
-                      param_attr=fluid.initializer.Xavier(
-                          uniform=False,
-                          fan_in=d_model * d_value,
-                          fan_out=n_head * d_value),
-                      bias_attr=False,
-                      num_flatten_dims=2)
+        v = layers.fc(
+            input=values,
+            size=d_value * n_head,
+            param_attr=fluid.initializer.Xavier(uniform=True),
+            # fan_in=d_model * d_value,
+            # fan_out=n_head * d_value),
+            bias_attr=False,
+            num_flatten_dims=2)
         return q, k, v
     def __split_heads(x, n_head):
@@ -84,7 +87,7 @@ def multi_head_attention(queries,
         # The value 0 in shape attr means copying the corresponding dimension
         # size of the input as the output dimension size.
         reshaped = layers.reshape(
-            x=x, shape=[0, -1, n_head, hidden_size // n_head])
+            x=x, shape=[0, 0, n_head, hidden_size // n_head])
         # permute the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
@@ -104,13 +107,13 @@ def multi_head_attention(queries,
         # size of the input as the output dimension size.
         return layers.reshape(
             x=trans_x,
-            shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))
+            shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
     def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
         Scaled Dot-Product Attention
         """
-        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
+        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
         weights = layers.reshape(
             x=layers.elementwise_add(
@@ -123,11 +126,15 @@ def multi_head_attention(queries,
         if dropout_rate:
             weights = layers.dropout(
                 weights, dropout_prob=dropout_rate, is_test=False)
         out = layers.matmul(weights, v)
         return out
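
The change above scales the query by d_key**-0.5 instead of d_model**-0.5, matching softmax(QK^T / sqrt(d_k))V from the Transformer paper. A single-head numpy sketch of that computation (dropout and reshapes omitted; illustrative only):

import numpy as np

def scaled_dot_product_attention(q, k, v, attn_bias=None, d_key=64):
    # q, k: (batch, len, d_key); v: (batch, len, d_value)
    product = np.matmul(q * d_key ** -0.5, np.swapaxes(k, -1, -2))
    if attn_bias is not None:  # large negative bias masks padded positions
        product = product + attn_bias
    product = product - product.max(axis=-1, keepdims=True)
    weights = np.exp(product)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, v)

out = scaled_dot_product_attention(
    np.random.rand(2, 5, 64), np.random.rand(2, 5, 64), np.random.rand(2, 5, 64))
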
     q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
+    if cache is not None:  # use cache and concat time steps
+        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
+        v = cache["v"] = layers.concat([cache["v"], v], axis=1)
     q = __split_heads(q, n_head)
     k = __split_heads(k, n_head)
     v = __split_heads(v, n_head)
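
During fast (incremental) decoding, cache holds the keys and values projected at all previous time steps for this decoder layer, so each new step only projects its own token and concatenates along the time axis. A schematic numpy sketch of that bookkeeping (shapes are assumptions, not the Fluid implementation):

import numpy as np

def append_to_cache(cache, new_k, new_v):
    # new_k / new_v: projections of the current step, shape (batch, 1, d)
    # cache["k"] / cache["v"]: everything decoded so far, shape (batch, t, d)
    cache["k"] = np.concatenate([cache["k"], new_k], axis=1)
    cache["v"] = np.concatenate([cache["v"], new_v], axis=1)
    return cache["k"], cache["v"]  # attention now covers t + 1 positions

cache = {"k": np.zeros((2, 0, 64)), "v": np.zeros((2, 0, 64))}
for _ in range(3):  # three decoding steps
    k, v = append_to_cache(cache, np.ones((2, 1, 64)), np.ones((2, 1, 64)))
assert cache["k"].shape == (2, 3, 64)
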
@@ -136,7 +143,6 @@ def multi_head_attention(queries,
                                                   dropout_rate)
     out = __combine_heads(ctx_multiheads)
     # Project back to the model size.
     proj_out = layers.fc(input=out,
                          size=d_model,
@@ -146,23 +152,32 @@ def multi_head_attention(queries,
     return proj_out
-def positionwise_feed_forward(x, d_inner_hid, d_hid):
+def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate=0.):
     """
     Position-wise Feed-Forward Networks.
     This module consists of two linear transformations with a ReLU activation
     in between, which is applied to each position separately and identically.
     """
-    hidden = layers.fc(input=x,
-                       size=d_inner_hid,
-                       num_flatten_dims=2,
-                       param_attr=fluid.initializer.Uniform(
-                           low=-(d_hid**-0.5), high=(d_hid**-0.5)),
-                       act="relu")
-    out = layers.fc(input=hidden,
-                    size=d_hid,
-                    num_flatten_dims=2,
-                    param_attr=fluid.initializer.Uniform(
-                        low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
+    hidden = layers.fc(
+        input=x,
+        size=d_inner_hid,
+        num_flatten_dims=2,
+        param_attr=fluid.initializer.Xavier(uniform=True),
+        #param_attr=fluid.initializer.Uniform(
+        #    low=-(d_hid**-0.5), high=(d_hid**-0.5)),
+        bias_attr=True,
+        act="relu")
+    if dropout_rate:
+        hidden = layers.dropout(
+            hidden, dropout_prob=dropout_rate, is_test=False)
+    out = layers.fc(
+        input=hidden,
+        size=d_hid,
+        num_flatten_dims=2,
+        param_attr=fluid.initializer.Xavier(uniform=True),
+        #param_attr=fluid.initializer.Uniform(
+        #    low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)),
+        bias_attr=True)
     return out
@@ -200,6 +215,7 @@ def prepare_encoder(src_word,
                     src_max_len,
                     dropout_rate=0.,
                     src_data_shape=None,
+                    word_emb_param_name=None,
                     pos_enc_param_name=None):
     """Add word embeddings and position encodings.
     The output tensor has a shape of:
@@ -209,7 +225,10 @@ def prepare_encoder(src_word,
     src_word_emb = layers.embedding(
         src_word,
         size=[src_vocab_size, src_emb_dim],
-        param_attr=fluid.initializer.Normal(0., 1.))
+        param_attr=fluid.ParamAttr(
+            name=word_emb_param_name,
+            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
+    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
     src_pos_enc = layers.embedding(
         src_pos,
         size=[src_max_len, src_emb_dim],
@@ -218,7 +237,7 @@ def prepare_encoder(src_word,
     enc_input = src_word_emb + src_pos_enc
     enc_input = layers.reshape(
         x=enc_input,
-        shape=[-1, src_max_len, src_emb_dim],
+        shape=[batch_size, seq_len, src_emb_dim],
         actual_shape=src_data_shape)
     return layers.dropout(
         enc_input, dropout_prob=dropout_rate,
@@ -226,9 +245,14 @@ def prepare_encoder(src_word,
 prepare_encoder = partial(
-    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
+    prepare_encoder,
+    word_emb_param_name=word_emb_param_names[0],
+    pos_enc_param_name=pos_enc_param_names[0])
 prepare_decoder = partial(
-    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
+    prepare_encoder,
+    word_emb_param_name=word_emb_param_names[0]
+    if WEIGHT_SHARING else word_emb_param_names[1],
+    pos_enc_param_name=pos_enc_param_names[1])
 def encoder_layer(enc_input,
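
Two things change in input preparation: embeddings are drawn from a named table (shared between encoder and decoder when WEIGHT_SHARING is on) and are scaled by sqrt(d_model) before position encodings are added. A compact numpy sketch of that step (names and shapes here are illustrative only):

import numpy as np

d_model, vocab_size, max_len = 512, 50000, 256
word_emb_table = np.random.normal(0., d_model ** -0.5, (vocab_size, d_model))
pos_enc_table = np.random.rand(max_len + 1, d_model)  # stands in for the sinusoid table

def prepare_input(word_ids, pos_ids):
    # Shared embedding lookup, scaled by sqrt(d_model), plus position encoding;
    # the dropout applied afterwards in the model is omitted here.
    return word_emb_table[word_ids] * d_model ** 0.5 + pos_enc_table[pos_ids]

enc_input = prepare_input(np.array([[5, 42, 7]]), np.array([[1, 2, 3]]))
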
@@ -247,13 +271,14 @@ def encoder_layer(enc_input,
     with the post_process_layer to add residual connection, layer normalization
     and dropout.
     """
-    attn_output = multi_head_attention(
-        enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
-        n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
-    attn_output = post_process_layer(enc_input, attn_output, "dan",
-                                     dropout_rate)
-    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
-    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)
+    q = k = v = pre_process_layer(enc_input, "n")
+    attn_output = multi_head_attention(q, k, v, attn_bias, d_key, d_value,
+                                       d_model, n_head, dropout_rate,
+                                       pre_softmax_shape, post_softmax_shape)
+    attn_output = post_process_layer(enc_input, attn_output, "da", dropout_rate)
+    ffd_output = positionwise_feed_forward(
+        pre_process_layer(attn_output, "n"), d_inner_hid, d_model, dropout_rate)
+    return post_process_layer(attn_output, ffd_output, "da", dropout_rate)
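
Switching the processing command from "dan" to "n" before and "da" after each sub-layer moves layer normalization in front of the sub-layer (pre-norm) instead of after the residual addition (post-norm). Schematically, with sublayer standing for either self-attention or the feed-forward network:

import numpy as np

def layer_norm(x, eps=1e-6):
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

# post-norm ("dan" after the sub-layer): y = layer_norm(x + dropout(sublayer(x)))
# pre-norm  ("n" before, "da" after):    y = x + dropout(sublayer(layer_norm(x)))
def pre_norm_block(x, sublayer):
    return x + sublayer(layer_norm(x))  # dropout omitted for clarity

y = pre_norm_block(np.random.rand(2, 5, 512), lambda t: 0.5 * t)
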
 def encoder(enc_input,
@@ -284,6 +309,7 @@ def encoder(enc_input,
             pre_softmax_shape,
             post_softmax_shape, )
         enc_input = enc_output
+    enc_output = pre_process_layer(enc_output, "n")
     return enc_output
@@ -300,15 +326,17 @@ def decoder_layer(dec_input,
                   slf_attn_pre_softmax_shape=None,
                   slf_attn_post_softmax_shape=None,
                   src_attn_pre_softmax_shape=None,
-                  src_attn_post_softmax_shape=None):
+                  src_attn_post_softmax_shape=None,
+                  cache=None):
     """ The layer to be stacked in the decoder part.
     The structure of this module is similar to that in the encoder part except
     that a multi-head attention is added to implement encoder-decoder attention.
     """
+    q = k = v = pre_process_layer(dec_input, "n")
     slf_attn_output = multi_head_attention(
-        dec_input,
-        dec_input,
-        dec_input,
+        q,
+        k,
+        v,
         slf_attn_bias,
         d_key,
         d_value,
@@ -316,14 +344,15 @@ def decoder_layer(dec_input,
         n_head,
         dropout_rate,
         slf_attn_pre_softmax_shape,
-        slf_attn_post_softmax_shape, )
+        slf_attn_post_softmax_shape,
+        cache, )
     slf_attn_output = post_process_layer(
         dec_input,
         slf_attn_output,
-        "dan",  # residual connection + dropout + layer normalization
+        "da",  # dropout + residual connection
         dropout_rate, )
     enc_attn_output = multi_head_attention(
-        slf_attn_output,
+        pre_process_layer(slf_attn_output, "n"),
         enc_output,
         enc_output,
         dec_enc_attn_bias,
@@ -337,16 +366,17 @@ def decoder_layer(dec_input,
     enc_attn_output = post_process_layer(
         slf_attn_output,
         enc_attn_output,
-        "dan",  # residual connection + dropout + layer normalization
+        "da",  # dropout + residual connection
         dropout_rate, )
     ffd_output = positionwise_feed_forward(
-        enc_attn_output,
+        pre_process_layer(enc_attn_output, "n"),
         d_inner_hid,
-        d_model, )
+        d_model,
+        dropout_rate, )
     dec_output = post_process_layer(
         enc_attn_output,
         ffd_output,
-        "dan",  # residual connection + dropout + layer normalization
+        "da",  # dropout + residual connection
         dropout_rate, )
     return dec_output
@@ -365,27 +395,20 @@ def decoder(dec_input,
             slf_attn_pre_softmax_shape=None,
             slf_attn_post_softmax_shape=None,
             src_attn_pre_softmax_shape=None,
-            src_attn_post_softmax_shape=None):
+            src_attn_post_softmax_shape=None,
+            caches=None):
     """
     The decoder is composed of a stack of identical decoder_layer layers.
     """
     for i in range(n_layer):
         dec_output = decoder_layer(
-            dec_input,
-            enc_output,
-            dec_slf_attn_bias,
-            dec_enc_attn_bias,
-            n_head,
-            d_key,
-            d_value,
-            d_model,
-            d_inner_hid,
-            dropout_rate,
-            slf_attn_pre_softmax_shape,
-            slf_attn_post_softmax_shape,
-            src_attn_pre_softmax_shape,
-            src_attn_post_softmax_shape, )
+            dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias, n_head,
+            d_key, d_value, d_model, d_inner_hid, dropout_rate,
+            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
+            src_attn_pre_softmax_shape, src_attn_post_softmax_shape, None
+            if caches is None else caches[i])
         dec_input = dec_output
+    dec_output = pre_process_layer(dec_output, "n")
     return dec_output
@@ -399,6 +422,8 @@ def make_all_inputs(input_fields):
             name=input_field,
             shape=input_descs[input_field][0],
             dtype=input_descs[input_field][1],
+            lod_level=input_descs[input_field][2]
+            if len(input_descs[input_field]) == 3 else 0,
             append_batch_size=False)
         inputs.append(input_var)
     return inputs
@@ -459,7 +484,6 @@ def transformer(
         logits=predict,
         label=label,
         soft_label=True if label_smooth_eps else False)
-    # cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
     weighted_cost = cost * weights
     sum_cost = layers.reduce_sum(weighted_cost)
     token_num = layers.reduce_sum(weights)
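
Dividing sum_cost by token_num yields the average loss per real token: weights (lbl_weight) is zero at padding positions, so paddings contribute neither to the summed loss nor to the token count. A small numpy sketch of that weighting (label smoothing omitted; illustrative only):

import numpy as np

def token_average_loss(logits, labels, weights):
    # logits: (n_tokens, vocab); labels, weights: (n_tokens,)
    logits = logits - logits.max(-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
    cost = -log_probs[np.arange(len(labels)), labels]  # per-token cross entropy
    weighted_cost = cost * weights                     # zero out padding tokens
    sum_cost, token_num = weighted_cost.sum(), weights.sum()
    return sum_cost / token_num

loss = token_average_loss(np.random.rand(6, 10),
                          np.array([1, 4, 2, 0, 0, 0]),
                          np.array([1., 1., 1., 0., 0., 0.]))
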
@@ -523,7 +547,8 @@ def wrap_decoder(trg_vocab_size,
                  d_inner_hid,
                  dropout_rate,
                  dec_inputs=None,
-                 enc_output=None):
+                 enc_output=None,
+                 caches=None):
     """
     The wrapper assembles together all needed layers for the decoder.
     """
@@ -563,13 +588,23 @@ def wrap_decoder(trg_vocab_size,
         slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape,
         src_attn_pre_softmax_shape,
-        src_attn_post_softmax_shape, )
+        src_attn_post_softmax_shape,
+        caches, )
     # Return logits for training and probs for inference.
-    predict = layers.reshape(
-        x=layers.fc(input=dec_output,
-                    size=trg_vocab_size,
-                    bias_attr=False,
-                    num_flatten_dims=2),
-        shape=[-1, trg_vocab_size],
-        act="softmax" if dec_inputs is None else None)
+    if not WEIGHT_SHARING:
+        predict = layers.reshape(
+            x=layers.fc(input=dec_output,
+                        size=trg_vocab_size,
+                        bias_attr=False,
+                        num_flatten_dims=2),
+            shape=[-1, trg_vocab_size],
+            act="softmax" if dec_inputs is None else None)
+    else:
+        predict = layers.reshape(
+            x=layers.matmul(
+                x=dec_output,
+                y=fluid.get_var(word_emb_param_names[0]),
+                transpose_y=True),
+            shape=[-1, trg_vocab_size],
+            act="softmax" if dec_inputs is None else None)
     return predict
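
With WEIGHT_SHARING enabled, the output projection reuses the transposed word embedding table instead of a separate trg_vocab_size fc layer, so the softmax weights are tied to the embeddings. In plain numpy the tied projection is just a matmul against the transposed table (an illustration with made-up shapes):

import numpy as np

d_model, trg_vocab_size = 512, 50000
word_emb_table = np.random.normal(0., d_model ** -0.5, (trg_vocab_size, d_model))

def tied_logits(dec_output):
    # dec_output: (n_tokens, d_model) -> logits: (n_tokens, trg_vocab_size);
    # no extra output parameters, the embedding table is reused.
    return dec_output.dot(word_emb_table.T)

logits = tied_logits(np.random.rand(3, d_model))
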
@@ -288,6 +288,7 @@ def train(args):
             start_mark=args.special_token[0],
             end_mark=args.special_token[1],
             unk_mark=args.special_token[2],
+            max_length=ModelHyperParams.max_length,
             clip_last_batch=False)
     train_data = read_multiple(reader=train_data.batch_generator)
@@ -319,6 +320,7 @@ def train(args):
             start_mark=args.special_token[0],
             end_mark=args.special_token[1],
             unk_mark=args.special_token[2],
+            max_length=ModelHyperParams.max_length,
             clip_last_batch=False,
             shuffle=False,
             shuffle_batch=False)
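
The new max_length argument lets the data reader drop sentence pairs longer than ModelHyperParams.max_length, so the position encoding table is never indexed out of range. A hypothetical sketch of such a filter (the actual reader implementation is not shown in this diff):

def filter_by_length(pair_reader, max_length=256):
    # pair_reader yields (src_ids, trg_ids) token-id lists.
    def reader():
        for src_ids, trg_ids in pair_reader():
            if len(src_ids) <= max_length and len(trg_ids) <= max_length:
                yield src_ids, trg_ids
    return reader

# usage: train_reader = filter_by_length(raw_reader, ModelHyperParams.max_length)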