Unverified commit ee13a2ab authored by LiuChiachi, committed by GitHub

Add transformer generate square subsequent mask api (#27651)

* add transformer generate square subsequent mask api

* add dtype for input, update doc, use -np.inf

Parent 29f49229
@@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase):
        trans_output = transformer(src, tgt, src_mask, tgt_mask,
                                   memory_mask)
    def test_generate_square_subsequent_mask(self):
        length = 5
        d_model, n_head, dim_feedforward = 8, 4, 64
        transformer = Transformer(
            d_model, n_head, dim_feedforward=dim_feedforward)
        mask = transformer.generate_square_subsequent_mask(length)
if __name__ == "__main__":
    unittest.main()
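The new test case only constructs the mask without asserting its contents. As a minimal sketch (not part of this commit), the result could be compared against a NumPy reference in which entries strictly above the diagonal are -inf and all other entries are zero:

import numpy as np
import paddle
from paddle.nn.layer.transformer import Transformer

# Illustrative check, not in the committed test: build the mask and compare
# it with an equivalent NumPy construction.
length = 5
transformer = Transformer(8, 4, dim_feedforward=64)
mask = transformer.generate_square_subsequent_mask(length)
expected = np.triu(np.full((length, length), -np.inf), k=1)
np.testing.assert_allclose(mask.numpy(), expected)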
@@ -24,7 +24,9 @@ __all__ = [
import copy
import collections
import numpy as np
import paddle
from .common import Linear, Dropout
from .norm import LayerNorm
from .. import functional as F
@@ -1174,3 +1176,39 @@ class Transformer(Layer):
        output = self.decoder(
            tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        return output
    def generate_square_subsequent_mask(self, length):
        """
        Generate a square mask for the sequence. The mask ensures that the
        predictions for position i can depend only on the known outputs at
        positions less than i.

        Parameters:
            length (int|Tensor): The length of the sequence.

        Returns:
            Tensor: Generated square mask according to the given length.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.nn.layer.transformer import Transformer
                length = 5
                d_model, n_head, dim_feedforward = 8, 4, 64
                transformer_paddle = Transformer(
                    d_model, n_head, dim_feedforward=dim_feedforward)
                mask = transformer_paddle.generate_square_subsequent_mask(length)
                print(mask.numpy())

                # [[ 0. -inf -inf -inf -inf]
                # [ 0. 0. -inf -inf -inf]
                # [ 0. 0. 0. -inf -inf]
                # [ 0. 0. 0. 0. -inf]
                # [ 0. 0. 0. 0. 0.]]
        """
        return paddle.tensor.triu(
            (paddle.ones(
                (length, length), dtype=paddle.get_default_dtype()) * -np.inf),
            1)
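For reference, a rough usage sketch (illustrative shapes and hyper-parameters, not part of this commit) of feeding the generated mask back into the Transformer forward pass as tgt_mask, assuming the 2-D mask broadcasts over the batch and head dimensions as in the attention layers above:

import paddle
from paddle.nn.layer.transformer import Transformer

d_model, n_head, dim_feedforward = 8, 4, 64
batch_size, src_len, tgt_len = 2, 6, 5
transformer = Transformer(d_model, n_head, dim_feedforward=dim_feedforward)

# Batch-first inputs of shape [batch_size, seq_len, d_model].
src = paddle.rand((batch_size, src_len, d_model))
tgt = paddle.rand((batch_size, tgt_len, d_model))

# The square subsequent mask keeps each target position from attending to
# later positions during decoding.
tgt_mask = transformer.generate_square_subsequent_mask(tgt_len)
output = transformer(src, tgt, tgt_mask=tgt_mask)  # [batch_size, tgt_len, d_model]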