# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Positional Encoding Module."""
import math

import paddle
from paddle import nn


class PositionalEncoding(nn.Layer):
    """Sinusoidal positional encoding.

    Parameters
    ----------
    d_model : int
        Embedding dimension.
    dropout_rate : float
        Dropout rate.
    max_len : int
        Maximum input length.
    dtype : str
        dtype of the encoding table.
    reverse : bool
        Whether to reverse the input position.
    """

    def __init__(self,
                 d_model,
                 dropout_rate,
                 max_len=5000,
                 dtype="float32",
                 reverse=False):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.reverse = reverse
        # Inputs are scaled by sqrt(d_model) before the encoding is added.
        self.xscale = math.sqrt(self.d_model)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.pe = None
        self.dtype = dtype
        # Pre-build the table for a dummy (1, max_len) input.
        self.extend_pe(paddle.expand(paddle.zeros([1]), (1, max_len)))

    def extend_pe(self, x):
        """Rebuild the positional-encoding table to match the length of ``x``."""
        num_pos = paddle.shape(x)[1]
        table = paddle.zeros([num_pos, self.d_model])
        if self.reverse:
            # Positions counted down from T-1 to 0.
            position = paddle.arange(
                num_pos - 1, -1, -1.0, dtype=self.dtype).unsqueeze(1)
        else:
            position = paddle.arange(
                0, num_pos, dtype=self.dtype).unsqueeze(1)
        # Standard transformer frequencies: 10000^(-2i/d_model).
        div_term = paddle.exp(
            paddle.arange(0, self.d_model, 2, dtype=self.dtype) *
            -(math.log(10000.0) / self.d_model))
        table[:, 0::2] = paddle.sin(position * div_term)
        table[:, 1::2] = paddle.cos(position * div_term)
        # Keep a leading batch axis so it broadcasts over the batch.
        self.pe = table.unsqueeze(0)

    def forward(self, x: paddle.Tensor):
        """Add positional encoding.

        Parameters
        ----------
        x : paddle.Tensor
            Input tensor (batch, time, `*`).

        Returns
        ----------
        paddle.Tensor
            Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        seq_len = paddle.shape(x)[1]
        x = x * self.xscale + self.pe[:, :seq_len]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    See Sec. 3.2  https://arxiv.org/abs/1809.08895

    Parameters
    ----------
        d_model : int
            Embedding dimension.
        dropout_rate : float
            Dropout rate.
        max_len : int
            Maximum input length.
        dtype : str
            dtype of param
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, dtype="float32"):
        """Initialize class."""
        super().__init__(
            d_model=d_model,
            dropout_rate=dropout_rate,
            max_len=max_len,
            dtype=dtype)
        # Trainable scalar that scales the positional encoding (init to 1).
        x = paddle.ones([1], dtype=self.dtype)
        self.alpha = paddle.create_parameter(
            shape=x.shape,
            dtype=self.dtype,
            default_initializer=paddle.nn.initializer.Assign(x))

    def reset_parameters(self):
        """Reset ``alpha`` to its initial value of 1.

        The value is written in place with ``set_value`` so ``alpha`` remains
        the same trainable parameter. Rebinding ``self.alpha = paddle.ones([1])``
        (the previous behavior) replaced the parameter with a plain tensor,
        silently dropping it from ``parameters()``/the optimizer and ignoring
        ``self.dtype``.
        """
        self.alpha.set_value(paddle.ones([1], dtype=self.dtype))

    def forward(self, x):
        """Add positional encoding.

        Parameters
        ----------
            x : paddle.Tensor
                Input tensor (batch, time, `*`).
        Returns
        ----------
            paddle.Tensor
                Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        T = paddle.shape(x)[1]
        x = x + self.alpha * self.pe[:, :T]
        return self.dropout(x)