diff --git a/python/paddle/fluid/tests/unittests/test_transformer_api.py b/python/paddle/fluid/tests/unittests/test_transformer_api.py
index 067d1ea5f73bf7d7af8a3511fa18dbc38b148656..3133aad0f485363583610ad0b7ee5d0e80ed2146 100644
--- a/python/paddle/fluid/tests/unittests/test_transformer_api.py
+++ b/python/paddle/fluid/tests/unittests/test_transformer_api.py
@@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase):
             trans_output = transformer(src, tgt, src_mask, tgt_mask,
                                        memory_mask)
 
+    def test_generate_square_subsequent_mask(self):
+        length = 5
+        d_model, n_head, dim_feedforward = 8, 4, 64
+        transformer = Transformer(
+            d_model, n_head, dim_feedforward=dim_feedforward)
+        mask = transformer.generate_square_subsequent_mask(length)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py
index 4b199d5816c808d4975c51bc154ad21d46f135eb..2f4d573a5b6becd8aa89eb152f76a41481e61a84 100644
--- a/python/paddle/nn/layer/transformer.py
+++ b/python/paddle/nn/layer/transformer.py
@@ -24,7 +24,9 @@ __all__ = [
 
 import copy
 import collections
+import numpy as np
 
+import paddle
 from .common import Linear, Dropout
 from .norm import LayerNorm
 from .. import functional as F
@@ -1174,3 +1176,18 @@ class Transformer(Layer):
         output = self.decoder(
             tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
         return output
+
+    def generate_square_subsequent_mask(self, length):
+        """
+        Generate a square mask for the sequence.
+
+        Parameters:
+            length (int): The length of sequence.
+
+        Returns:
+            Tensor: Generated square mask according to the given length.
+        """
+        return paddle.tensor.triu(
+            -(paddle.ones(
+                (length, length), dtype=paddle.get_default_dtype()) * np.inf),
+            1)
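
For reference, here is a usage sketch of the new API (not part of the patch). It reuses the same constructor arguments as the unit test above and adds an expected-value check against a numpy reconstruction of the mask; the assertion and printout are illustrative additions, assuming Paddle 2.x in dynamic-graph mode.

```python
# Illustrative usage of the new generate_square_subsequent_mask API.
# Not part of the patch; assumes Paddle 2.x dynamic-graph (imperative) mode.
import numpy as np
import paddle
from paddle.nn import Transformer

length = 5
d_model, n_head, dim_feedforward = 8, 4, 64
transformer = Transformer(d_model, n_head, dim_feedforward=dim_feedforward)

mask = transformer.generate_square_subsequent_mask(length)

# The mask is -inf strictly above the diagonal and 0 elsewhere, so when it is
# added to the pre-softmax attention scores, position i can attend only to
# positions <= i (softmax sends -inf logits to zero weight).
expected = np.triu(np.full((length, length), -np.inf), k=1)
np.testing.assert_array_equal(mask.numpy(), expected)
print(mask.numpy())
# [[  0. -inf -inf -inf -inf]
#  [  0.   0. -inf -inf -inf]
#  [  0.   0.   0. -inf -inf]
#  [  0.   0.   0.   0. -inf]
#  [  0.   0.   0.   0.   0.]]
```

The additive float form (-inf/0 rather than a boolean mask) fits masks that are summed directly into the scaled dot-product attention logits before the softmax, which is why the helper builds the mask from `-inf` values instead of returning a 0/1 matrix.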