提交 ecbf5e4a 编写于 作者: L LiuChiaChi

fix bugs about padding_idx

上级 dd3d25ce
......@@ -81,6 +81,7 @@ class Encoder(Layer):
self.hidden_size = hidden_size
self.num_layers = num_layers
self.dtype = dtype
self.padding_idx = padding_idx
self.embedder = Embedding(
vocab_size,
hidden_size,
......@@ -101,7 +102,7 @@ class Encoder(Layer):
outs, (final_h, final_c) = self.lstm(
src_emb, sequence_length=src_sequence_length)
enc_len_mask = (src != 2).astype(self.dtype)
enc_len_mask = (src != self.padding_idx).astype(self.dtype)
enc_padding_mask = (enc_len_mask - 1.0) * 1e9
return [final_h, final_c], outs, enc_padding_mask
......@@ -200,6 +201,7 @@ class Decoder(Layer):
self.num_layers = num_layers
self.init_scale = init_scale
self.dtype = dtype
self.padding_idx = padding_idx
self.embedder = Embedding(
vocab_size,
hidden_size,
......@@ -241,7 +243,7 @@ class Decoder(Layer):
logits=dec_output, label=label, soft_label=False)
loss = paddle.squeeze(loss, axis=[2])
trg_mask = (trg != 2).astype(self.dtype)
trg_mask = (trg != self.padding_idx).astype(self.dtype)
loss = loss * trg_mask
loss = paddle.reduce_mean(loss, dim=[0])
......
/models/dygraph/seq2seq_attn/data
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册