未验证 提交 cb760565 编写于 作者: Z Zhang Ting 提交者: GitHub

set stop_gradient (#5178)

上级 26582d45
......@@ -276,8 +276,10 @@ class TransformerModel(nn.Layer):
src_slf_attn_bias = paddle.cast(
src_word == self.bos_id,
dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
src_slf_attn_bias.stop_gradient = True
trg_slf_attn_bias = self.transformer.generate_square_subsequent_mask(
trg_max_len)
trg_slf_attn_bias.stop_gradient = True
trg_src_attn_bias = src_slf_attn_bias
src_pos = paddle.cast(
src_word != self.bos_id, dtype="int64") * paddle.arange(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册