未验证 提交 ee13a2ab 编写于 作者: L LiuChiachi 提交者: GitHub

Add transformer generate square subsequent mask api (#27651)

* add transformer generate square subsequent mask api

* add dtype for input, update doc, use -np.inf

* add dtype for input, update doc, use -np.inf
上级 29f49229
......@@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase):
trans_output = transformer(src, tgt, src_mask, tgt_mask,
memory_mask)
def test_generate_square_subsequent_mask(self):
length = 5
d_model, n_head, dim_feedforward = 8, 4, 64
transformer = Transformer(
d_model, n_head, dim_feedforward=dim_feedforward)
mask = transformer.generate_square_subsequent_mask(length)
if __name__ == "__main__":
unittest.main()
......@@ -24,7 +24,9 @@ __all__ = [
import copy
import collections
import numpy as np
import paddle
from .common import Linear, Dropout
from .norm import LayerNorm
from .. import functional as F
......@@ -1174,3 +1176,39 @@ class Transformer(Layer):
output = self.decoder(
tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
return output
def generate_square_subsequent_mask(self, length):
"""
Generate a square mask for the sequence. The mask ensures that the
predictions for position i can depend only on the known outputs at
positions less than i.
Parameters:
length (int|Tensor): The length of sequence.
Returns:
Tensor: Generated square mask according to the given length.
Examples:
.. code-block:: python
import paddle
from paddle.nn.layer.transformer import Transformer
length = 5
d_model, n_head, dim_feedforward = 8, 4, 64
transformer_paddle = Transformer(
d_model, n_head, dim_feedforward=dim_feedforward)
mask = transformer_paddle.generate_square_subsequent_mask(length)
print(mask.numpy())
# [[ 0. -inf -inf -inf -inf]
# [ 0. 0. -inf -inf -inf]
# [ 0. 0. 0. -inf -inf]
# [ 0. 0. 0. 0. -inf]
# [ 0. 0. 0. 0. 0.]]
"""
return paddle.tensor.triu(
(paddle.ones(
(length, length), dtype=paddle.get_default_dtype()) * -np.inf),
1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册