未验证 提交 31060483 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #819 from guoshengCS/refine-transformer-logit

Avoid predicting <pad> by restricting the size of  fc_layer in Transformer
......@@ -43,21 +43,16 @@ class InferTaskConfig(object):
class ModelHyperParams(object):
# Dictionary size for source and target language. This model directly uses
# paddle.dataset.wmt16 in which <bos>, <eos> and <unk> token has
# alreay been added, but the <pad> token is not added. Transformer requires
# sequences in a mini-batch are padded to have the same length. A <pad> token is
# added into the original dictionary in paddle.dateset.wmt16.
# This model directly uses paddle.dataset.wmt16 in which <bos>, <eos> and
# <unk> token has alreay been added. As for the <pad> token, any token
# included in dict can be used to pad, since the paddings' loss will be
# masked out and make no effect on parameter gradients.
# size of source word dictionary.
src_vocab_size = 10000
# index for <pad> token in source language.
src_pad_idx = src_vocab_size
# size of target word dictionay
trg_vocab_size = 10000
# index for <pad> token in target language.
trg_pad_idx = trg_vocab_size
# index for <bos> token
bos_idx = 0
......@@ -66,11 +61,10 @@ class ModelHyperParams(object):
# index for <unk> token
unk_idx = 2
# position value corresponding to the <pad> token.
pos_pad_idx = 0
# max length of sequences. It should plus 1 to include position
# padding token for position encoding.
# max length of sequences.
# The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# token for position encoding.
max_length = 50
# the dimension for word embeddings, which is also the last dimension of
......
......@@ -41,7 +41,7 @@ def translate_batch(exe,
src_pad_idx,
n_head,
is_target=False,
return_pos=True,
is_label=False,
return_attn_bias=True,
return_max_len=False)
# Append the data shape input to reshape the output of embedding layer.
......@@ -250,22 +250,20 @@ def main():
encoder_program = fluid.Program()
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
ModelHyperParams.trg_vocab_size + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)
# Load model parameters of encoder and decoder separately from the saved
# transformer model.
......@@ -301,9 +299,6 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
# Append the <pad> token since the dict provided by dataset.wmt16 does
# not include it.
trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
......@@ -327,19 +322,22 @@ def main():
for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data],
exe,
[item[0] for item in data],
encoder_program,
encoder_input_data_names, [enc_output.name],
encoder_input_data_names,
[enc_output.name],
decoder_program,
decoder_input_data_names, [predict.name],
decoder_input_data_names,
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx,
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
......
......@@ -199,10 +199,8 @@ def prepare_encoder(src_word,
src_pos,
src_vocab_size,
src_emb_dim,
src_pad_idx,
src_max_len,
dropout_rate=0.,
pos_pad_idx=0,
src_data_shape=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
......@@ -214,12 +212,10 @@ def prepare_encoder(src_word,
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=src_pad_idx,
param_attr=fluid.initializer.Normal(0., 1.))
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
padding_idx=pos_pad_idx,
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc
......@@ -480,12 +476,16 @@ def make_inputs(input_data_names,
append_batch_size=False)
input_layers += [slf_attn_post_softmax_shape]
if src_attn_shape_flag:
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
src_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_pre_softmax_shape]
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
src_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[4],
......@@ -516,10 +516,7 @@ def transformer(
d_value,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
dropout_rate, ):
enc_inputs = make_inputs(
encoder_input_data_names,
n_head,
......@@ -543,8 +540,6 @@ def transformer(
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_inputs, )
dec_inputs = make_inputs(
......@@ -570,8 +565,6 @@ def transformer(
d_model,
d_inner_hid,
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_inputs,
enc_output, )
......@@ -606,8 +599,6 @@ def wrap_encoder(src_vocab_size,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_inputs=None):
"""
The wrapper assembles together all needed layers for the encoder.
......@@ -637,10 +628,8 @@ def wrap_encoder(src_vocab_size,
src_pos,
src_vocab_size,
d_model,
src_pad_idx,
max_length,
dropout_rate,
pos_pad_idx,
src_data_shape, )
enc_output = encoder(
enc_input,
......@@ -666,8 +655,6 @@ def wrap_decoder(trg_vocab_size,
d_model,
d_inner_hid,
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_inputs=None,
enc_output=None):
"""
......@@ -701,10 +688,8 @@ def wrap_decoder(trg_vocab_size,
trg_pos,
trg_vocab_size,
d_model,
trg_pad_idx,
max_length,
dropout_rate,
pos_pad_idx,
trg_data_shape, )
dec_output = decoder(
dec_input,
......
......@@ -15,7 +15,7 @@ def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
return_pos=True,
is_label=False,
return_attn_bias=True,
return_max_len=True):
"""
......@@ -24,14 +24,20 @@ def pad_batch_data(insts,
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if return_pos:
inst_pos = np.array([[
pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst)
] for inst in inst_data])
if is_label: # label weight
inst_weight = np.array(
[[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
range(1, len(inst) + 1) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
......@@ -84,9 +90,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
False, False, False, False)
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
lbl_word, lbl_weight = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False)
input_dict = dict(
zip(input_data_names, [
......@@ -105,13 +116,11 @@ def main():
exe = fluid.Executor(place)
sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps, place,
......@@ -145,8 +154,8 @@ def main():
for batch_id, data in enumerate(val_data()):
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
test_sum_cost, test_token_num = exe.run(
test_program,
......@@ -171,10 +180,12 @@ def main():
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册