提交 bfd73890 编写于 作者: L LiuChiaChi

Add the Transformer.generate_square_subsequent_mask API

上级 6e41143f
...@@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase): ...@@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase):
trans_output = transformer(src, tgt, src_mask, tgt_mask, trans_output = transformer(src, tgt, src_mask, tgt_mask,
memory_mask) memory_mask)
def test_generate_square_subsequent_mask(self):
    """Check that the generated subsequent mask is square with the requested size."""
    length = 5
    d_model, n_head, dim_feedforward = 8, 4, 64
    transformer = Transformer(
        d_model, n_head, dim_feedforward=dim_feedforward)
    mask = transformer.generate_square_subsequent_mask(length)
    # list() keeps the check independent of whether shape is a list or a tuple.
    self.assertEqual(list(mask.shape), [length, length])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -24,7 +24,9 @@ __all__ = [ ...@@ -24,7 +24,9 @@ __all__ = [
import copy import copy
import collections import collections
import numpy as np
import paddle
from .common import Linear, Dropout from .common import Linear, Dropout
from .norm import LayerNorm from .norm import LayerNorm
from .. import functional as F from .. import functional as F
...@@ -1174,3 +1176,18 @@ class Transformer(Layer): ...@@ -1174,3 +1176,18 @@ class Transformer(Layer):
output = self.decoder( output = self.decoder(
tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask) tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
return output return output
def generate_square_subsequent_mask(self, length):
    """
    Build a ``length`` x ``length`` mask tensor for a sequence.

    A matrix filled with ``-inf`` is created in Paddle's current default
    dtype and passed through ``triu`` with diagonal offset 1, so entries
    strictly above the main diagonal stay ``-inf`` while the main diagonal
    and everything below it become 0.

    Parameters:
        length (int): The length of sequence; used for both dimensions.

    Returns:
        Tensor: Generated square mask according to the given length.
    """
    mask_shape = (length, length)
    ones = paddle.ones(mask_shape, dtype=paddle.get_default_dtype())
    neg_inf = -(ones * np.inf)
    # Offset 1 excludes the main diagonal from the retained upper triangle.
    return paddle.tensor.triu(neg_inf, 1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册