From bfd7389079f7e6105fd79c53657441628e7f2385 Mon Sep 17 00:00:00 2001
From: LiuChiaChi <709153940@qq.com>
Date: Mon, 28 Sep 2020 03:15:26 +0000
Subject: [PATCH] add transformer generate square subsequent mask api

---
 .../tests/unittests/test_transformer_api.py |  7 +++++++
 python/paddle/nn/layer/transformer.py       | 17 +++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_transformer_api.py b/python/paddle/fluid/tests/unittests/test_transformer_api.py
index 067d1ea5f73..3133aad0f48 100644
--- a/python/paddle/fluid/tests/unittests/test_transformer_api.py
+++ b/python/paddle/fluid/tests/unittests/test_transformer_api.py
@@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase):
 
             trans_output = transformer(src, tgt, src_mask, tgt_mask,
                                        memory_mask)
 
+    def test_generate_square_subsequent_mask(self):
+        length = 5
+        d_model, n_head, dim_feedforward = 8, 4, 64
+        transformer = Transformer(
+            d_model, n_head, dim_feedforward=dim_feedforward)
+        mask = transformer.generate_square_subsequent_mask(length)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py
index 4b199d5816c..2f4d573a5b6 100644
--- a/python/paddle/nn/layer/transformer.py
+++ b/python/paddle/nn/layer/transformer.py
@@ -24,7 +24,9 @@ __all__ = [
 
 import copy
 import collections
+import numpy as np
 
+import paddle
 from .common import Linear, Dropout
 from .norm import LayerNorm
 from .. import functional as F
@@ -1174,3 +1176,18 @@ class Transformer(Layer):
         output = self.decoder(
             tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
         return output
+
+    def generate_square_subsequent_mask(self, length):
+        """
+        Generate a square mask for the sequence.
+
+        Parameters:
+            length (int): The length of sequence.
+
+        Returns:
+            Tensor: Generated square mask according to the given length.
+        """
+        return paddle.tensor.triu(
+            -(paddle.ones(
+                (length, length), dtype=paddle.get_default_dtype()) * np.inf),
+            1)
--
GitLab
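
Usage sketch (not part of the patch itself; it assumes Paddle 2.0+ in dynamic graph mode and the batch-first [batch_size, seq_len, d_model] inputs that paddle.nn.Transformer expects). The new method returns a causal mask with -inf strictly above the main diagonal and 0 elsewhere; passed as tgt_mask, it is added to the decoder self-attention scores so each target position can only attend to itself and to earlier positions:

import paddle
import paddle.nn as nn

length = 5
d_model, n_head, dim_feedforward = 8, 4, 64   # same sizes as the new unit test

transformer = nn.Transformer(d_model, n_head, dim_feedforward=dim_feedforward)

# Causal mask: -inf above the main diagonal, 0 elsewhere.
tgt_mask = transformer.generate_square_subsequent_mask(length)

src = paddle.rand((2, 10, d_model))      # (batch_size, src_len, d_model)
tgt = paddle.rand((2, length, d_model))  # (batch_size, tgt_len, d_model)

output = transformer(src, tgt, tgt_mask=tgt_mask)
print(output.shape)  # [2, 5, 8]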