diff --git a/python/paddle/fluid/tests/unittests/test_transformer_api.py b/python/paddle/fluid/tests/unittests/test_transformer_api.py
index 067d1ea5f73bf7d7af8a3511fa18dbc38b148656..3133aad0f485363583610ad0b7ee5d0e80ed2146 100644
--- a/python/paddle/fluid/tests/unittests/test_transformer_api.py
+++ b/python/paddle/fluid/tests/unittests/test_transformer_api.py
@@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase):
             trans_output = transformer(src, tgt, src_mask, tgt_mask,
                                        memory_mask)
 
+    def test_generate_square_subsequent_mask(self):
+        length = 5
+        d_model, n_head, dim_feedforward = 8, 4, 64
+        transformer = Transformer(
+            d_model, n_head, dim_feedforward=dim_feedforward)
+        mask = transformer.generate_square_subsequent_mask(length)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py
index 4b199d5816c808d4975c51bc154ad21d46f135eb..e6df5366d216cfc8d1b019057a12635360f77687 100644
--- a/python/paddle/nn/layer/transformer.py
+++ b/python/paddle/nn/layer/transformer.py
@@ -24,7 +24,9 @@ __all__ = [
 
 import copy
 import collections
+import numpy as np
 
+import paddle
 from .common import Linear, Dropout
 from .norm import LayerNorm
 from .. import functional as F
@@ -1174,3 +1176,39 @@ class Transformer(Layer):
         output = self.decoder(
             tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
         return output
+
+    def generate_square_subsequent_mask(self, length):
+        """
+        Generate a square mask for the sequence. The mask ensures that the
+        predictions for position i can depend only on the known outputs at
+        positions less than i.
+
+        Parameters:
+            length (int|Tensor): The length of sequence.
+
+        Returns:
+            Tensor: Generated square mask according to the given length.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                from paddle.nn.layer.transformer import Transformer
+                length = 5
+                d_model, n_head, dim_feedforward = 8, 4, 64
+                transformer_paddle = Transformer(
+                    d_model, n_head, dim_feedforward=dim_feedforward)
+                mask = transformer_paddle.generate_square_subsequent_mask(length)
+                print(mask.numpy())
+
+                # [[ 0. -inf -inf -inf -inf]
+                # [ 0. 0. -inf -inf -inf]
+                # [ 0. 0. 0. -inf -inf]
+                # [ 0. 0. 0. 0. -inf]
+                # [ 0. 0. 0. 0. 0.]]
+
+        """
+        return paddle.tensor.triu(
+            (paddle.ones(
+                (length, length), dtype=paddle.get_default_dtype()) * -np.inf),
+            1)