提交 bfd73890 编写于 作者: L LiuChiaChi

add transformer generate square subsequent mask api

上级 6e41143f
......@@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase):
trans_output = transformer(src, tgt, src_mask, tgt_mask,
memory_mask)
def test_generate_square_subsequent_mask(self):
length = 5
d_model, n_head, dim_feedforward = 8, 4, 64
transformer = Transformer(
d_model, n_head, dim_feedforward=dim_feedforward)
mask = transformer.generate_square_subsequent_mask(length)
if __name__ == "__main__":
unittest.main()
......@@ -24,7 +24,9 @@ __all__ = [
import copy
import collections
import numpy as np
import paddle
from .common import Linear, Dropout
from .norm import LayerNorm
from .. import functional as F
......@@ -1174,3 +1176,18 @@ class Transformer(Layer):
output = self.decoder(
tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
return output
def generate_square_subsequent_mask(self, length):
"""
Generate a square mask for the sequence.
Parameters:
length (int): The length of sequence.
Returns:
Tensor: Generated square mask according to the given length.
"""
return paddle.tensor.triu(
-(paddle.ones(
(length, length), dtype=paddle.get_default_dtype()) * np.inf),
1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册