提交 10de2bf3 编写于 作者: G guosheng

Avoid predicting <pad> by restricting the size of the final fc_layer in Transformer.

上级 f14db82d
...@@ -39,9 +39,10 @@ def translate_batch(exe, ...@@ -39,9 +39,10 @@ def translate_batch(exe,
enc_in_data = pad_batch_data( enc_in_data = pad_batch_data(
src_words, src_words,
src_pad_idx, src_pad_idx,
eos_idx,
n_head, n_head,
is_target=False, is_target=False,
return_pos=True, is_label=False,
return_attn_bias=True, return_attn_bias=True,
return_max_len=False) return_max_len=False)
# Append the data shape input to reshape the output of embedding layer. # Append the data shape input to reshape the output of embedding layer.
......
...@@ -724,10 +724,11 @@ def wrap_decoder(trg_vocab_size, ...@@ -724,10 +724,11 @@ def wrap_decoder(trg_vocab_size,
src_attn_post_softmax_shape, ) src_attn_post_softmax_shape, )
# Return logits for training and probs for inference. # Return logits for training and probs for inference.
predict = layers.reshape( predict = layers.reshape(
x=layers.fc(input=dec_output, x=layers.fc(
size=trg_vocab_size, input=dec_output,
size=trg_vocab_size - 1, # To exclude <pad>.
bias_attr=False, bias_attr=False,
num_flatten_dims=2), num_flatten_dims=2),
shape=[-1, trg_vocab_size], shape=[-1, trg_vocab_size - 1],
act="softmax" if dec_inputs is None else None) act="softmax" if dec_inputs is None else None)
return predict return predict
...@@ -13,9 +13,10 @@ from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \ ...@@ -13,9 +13,10 @@ from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
def pad_batch_data(insts, def pad_batch_data(insts,
pad_idx, pad_idx,
eos_idx,
n_head, n_head,
is_target=False, is_target=False,
return_pos=True, is_label=False,
return_attn_bias=True, return_attn_bias=True,
return_max_len=True): return_max_len=True):
""" """
...@@ -24,14 +25,22 @@ def pad_batch_data(insts, ...@@ -24,14 +25,22 @@ def pad_batch_data(insts,
""" """
return_list = [] return_list = []
max_len = max(len(inst) for inst in insts) max_len = max(len(inst) for inst in insts)
inst_data = np.array( # Since the predicted probs exclude <pad> (so that <pad> is never
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) # generated), also replace the <pad> padding in labels with <eos>.
inst_data = np.array([
inst + [eos_idx if is_label else pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_data.astype("int64").reshape([-1, 1])] return_list += [inst_data.astype("int64").reshape([-1, 1])]
if return_pos: if is_label: # label weight
inst_pos = np.array([[ inst_weight = np.array(
pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst) [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
] for inst in inst_data]) return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
range(1, len(inst) + 1) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])] return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias: if return_attn_bias:
if is_target: if is_target:
...@@ -57,14 +66,22 @@ def pad_batch_data(insts, ...@@ -57,14 +66,22 @@ def pad_batch_data(insts,
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
n_head, d_model): eos_idx, n_head, d_model):
""" """
Put all padded data needed by training into a dict. Put all padded data needed by training into a dict.
""" """
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) [inst[0] for inst in insts],
src_pad_idx,
eos_idx,
n_head,
is_target=False)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) [inst[1] for inst in insts],
trg_pad_idx,
eos_idx,
n_head,
is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32") [1, 1, trg_max_len, 1]).astype("float32")
...@@ -84,9 +101,15 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, ...@@ -84,9 +101,15 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
trg_src_attn_post_softmax_shape = np.array( trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32") trg_src_attn_bias.shape, dtype="int32")
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head, lbl_word, lbl_weight = pad_batch_data(
False, False, False, False) [inst[2] for inst in insts],
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1]) trg_pad_idx,
eos_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False)
input_dict = dict( input_dict = dict(
zip(input_data_names, [ zip(input_data_names, [
...@@ -146,8 +169,8 @@ def main(): ...@@ -146,8 +169,8 @@ def main():
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] + data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx, label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head, ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx,
ModelHyperParams.d_model) ModelHyperParams.n_head, ModelHyperParams.d_model)
test_sum_cost, test_token_num = exe.run( test_sum_cost, test_token_num = exe.run(
test_program, test_program,
feed=data_input, feed=data_input,
...@@ -174,8 +197,8 @@ def main(): ...@@ -174,8 +197,8 @@ def main():
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] + data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx, label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head, ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx,
ModelHyperParams.d_model) ModelHyperParams.n_head, ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input) lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(), outs = exe.run(fluid.framework.default_main_program(),
feed=data_input, feed=data_input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册