提交 9dd1696b 编写于 作者: G guosheng

Remove the losses from paddings in Transformer

上级 7bc7f4e1
......@@ -70,4 +70,5 @@ input_data_names = (
"src_slf_attn_bias",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"lbl_word", )
"lbl_word",
"lbl_weight", )
......@@ -474,4 +474,13 @@ def transformer(
dtype="int64",
append_batch_size=False)
cost = layers.cross_entropy(input=predict, label=gold)
return layers.mean(x=cost)
# The actual shape of weights in runtime is:
# [batch_size * max_trg_length_in_a_batch, 1].
# This is used to remove the losses resulting from paddings.
weights = layers.data(
name=input_data_names[8],
shape=[batch_size * max_length, 1],
dtype="float32",
append_batch_size=False)
weighted_cost = cost * weights
return layers.reduce_sum(weighted_cost)
......@@ -77,17 +77,18 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
False, False, False)
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
data_to_tensor([
src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
], input_data_names, input_dict, place)
return input_dict
def main():
avg_cost = transformer(
cost = transformer(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
......@@ -101,7 +102,7 @@ def main():
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
optimizer.minimize(avg_cost)
optimizer.minimize(cost)
train_data = paddle.batch(
paddle.reader.shuffle(
......@@ -130,10 +131,10 @@ def main():
ModelHyperParams.n_head, place)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
fetch_list=[avg_cost])
avg_cost_val = np.array(outs[0])
fetch_list=[cost])
cost_val = np.array(outs[0])
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
" avg_cost = " + str(avg_cost_val))
" avg_cost = " + str(cost_val))
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册