From ee13a2ab88c1896c2f73ebe7c9c78364b6befd54 Mon Sep 17 00:00:00 2001 From: LiuChiachi <709153940@qq.com> Date: Tue, 29 Sep 2020 10:24:40 +0800 Subject: [PATCH] Add transformer generate square subsequent mask api (#27651) * add transformer generate square subsequent mask api * add dtype for input, update doc, use -np.inf * add dtype for input, update doc, use -np.inf --- .../tests/unittests/test_transformer_api.py | 7 ++++ python/paddle/nn/layer/transformer.py | 38 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_transformer_api.py b/python/paddle/fluid/tests/unittests/test_transformer_api.py index 067d1ea5f7..3133aad0f4 100644 --- a/python/paddle/fluid/tests/unittests/test_transformer_api.py +++ b/python/paddle/fluid/tests/unittests/test_transformer_api.py @@ -609,6 +609,13 @@ class TestTransformer(unittest.TestCase): trans_output = transformer(src, tgt, src_mask, tgt_mask, memory_mask) + def test_generate_square_subsequent_mask(self): + length = 5 + d_model, n_head, dim_feedforward = 8, 4, 64 + transformer = Transformer( + d_model, n_head, dim_feedforward=dim_feedforward) + mask = transformer.generate_square_subsequent_mask(length) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index 4b199d5816..e6df5366d2 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -24,7 +24,9 @@ __all__ = [ import copy import collections +import numpy as np +import paddle from .common import Linear, Dropout from .norm import LayerNorm from .. import functional as F @@ -1174,3 +1176,39 @@ class Transformer(Layer): output = self.decoder( tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask) return output + + def generate_square_subsequent_mask(self, length): + """ + Generate a square mask for the sequence. 
The mask ensures that the + predictions for position i can depend only on the known outputs at + positions less than i. + + Parameters: + length (int|Tensor): The length of the sequence, i.e. the size of each side of the square mask. + + Returns: + Tensor: A mask of shape [length, length] in the default dtype, with 0 on and below the main diagonal and -inf above it. + + Examples: + .. code-block:: python + + import paddle + from paddle.nn.layer.transformer import Transformer + length = 5 + d_model, n_head, dim_feedforward = 8, 4, 64 + transformer_paddle = Transformer( + d_model, n_head, dim_feedforward=dim_feedforward) + mask = transformer_paddle.generate_square_subsequent_mask(length) + print(mask.numpy()) + + # [[ 0. -inf -inf -inf -inf] + # [ 0. 0. -inf -inf -inf] + # [ 0. 0. 0. -inf -inf] + # [ 0. 0. 0. 0. -inf] + # [ 0. 0. 0. 0. 0.]] + + """ + return paddle.tensor.triu( + (paddle.ones( + (length, length), dtype=paddle.get_default_dtype()) * -np.inf), + 1) -- GitLab