提交 10de2bf3 编写于 作者: G guosheng

Avoid predicting <pad> by restricting the size of the final fc_layer in Transformer.

上级 f14db82d
......@@ -39,9 +39,10 @@ def translate_batch(exe,
enc_in_data = pad_batch_data(
src_words,
src_pad_idx,
eos_idx,
n_head,
is_target=False,
return_pos=True,
is_label=False,
return_attn_bias=True,
return_max_len=False)
# Append the data shape input to reshape the output of embedding layer.
......
......@@ -724,10 +724,11 @@ def wrap_decoder(trg_vocab_size,
src_attn_post_softmax_shape, )
# Return logits for training and probs for inference.
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
x=layers.fc(
input=dec_output,
size=trg_vocab_size - 1, # To exclude <pad>.
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
shape=[-1, trg_vocab_size - 1],
act="softmax" if dec_inputs is None else None)
return predict
......@@ -13,9 +13,10 @@ from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
def pad_batch_data(insts,
pad_idx,
eos_idx,
n_head,
is_target=False,
return_pos=True,
is_label=False,
return_attn_bias=True,
return_max_len=True):
"""
......@@ -24,14 +25,22 @@ def pad_batch_data(insts,
"""
return_list = []
max_len = max(len(inst) for inst in insts)
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
# Since we restrict the predicted probs excluding the <pad> to avoid
# generating the <pad>, also replace the <pad> with others in labels.
inst_data = np.array([
inst + [eos_idx if is_label else pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if return_pos:
inst_pos = np.array([[
pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst)
] for inst in inst_data])
if is_label: # label weight
inst_weight = np.array(
[[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
range(1, len(inst) + 1) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
......@@ -57,14 +66,22 @@ def pad_batch_data(insts,
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
n_head, d_model):
eos_idx, n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
[inst[0] for inst in insts],
src_pad_idx,
eos_idx,
n_head,
is_target=False)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
[inst[1] for inst in insts],
trg_pad_idx,
eos_idx,
n_head,
is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
......@@ -84,9 +101,15 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
False, False, False, False)
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
lbl_word, lbl_weight = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
eos_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False)
input_dict = dict(
zip(input_data_names, [
......@@ -146,8 +169,8 @@ def main():
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
test_sum_cost, test_token_num = exe.run(
test_program,
feed=data_input,
......@@ -174,8 +197,8 @@ def main():
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册