# multihead_attention.py (committed by sserdoubleh)
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MultiheadAttention class.
"""

import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph import FC
import paddle.fluid.layers as layers

import plato.modules.functions as F


class MultiheadAttention(Layer):
    """
    Multi-head self-attention layer.

    One fused projection produces query, key and value from the input. They
    are split into `num_heads` heads, combined with scaled dot-product
    attention, merged back together and sent through an output projection.
    """

    def __init__(self, name_scope, hidden_dim, num_heads, dropout):
        """
        Create the two projection sub-layers and store the hyper-parameters.

        Args:
            name_scope: Name scope forwarded to the base `Layer`.
            hidden_dim: Total feature size; must be divisible by `num_heads`.
            num_heads: Number of attention heads.
            dropout: Value handed to `F.dropout` for the attention weights.
        """
        assert hidden_dim % num_heads == 0
        super().__init__(name_scope)

        per_head = hidden_dim // num_heads

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = per_head
        # Attention scores are multiplied by head_dim ** -0.5.
        self.scale = per_head ** -0.5

        # Fused projection: query, key and value come out of a single FC
        # whose output is three times the hidden size.
        self.linear_qkv = FC(
            name_scope=self.full_name(),
            size=3 * hidden_dim,
            num_flatten_dims=2,
        )
        # Projection applied after the heads have been merged back.
        self.linear_out = FC(
            name_scope=self.full_name(),
            size=hidden_dim,
            num_flatten_dims=2,
        )
        self.dropout = dropout

    def _split_heads(self, x, is_key=False):
        """
        Split the last dimension of `x` into (num_heads, head_dim).

        With `is_key=False` the result is laid out as
        [dim0, num_heads, dim1, head_dim]; with `is_key=True` the last two
        axes are swapped ([dim0, num_heads, head_dim, dim1]) so the result
        can be used directly as the right operand of the score matmul.
        """
        if is_key:
            axis_order = [0, 2, 3, 1]
        else:
            axis_order = [0, 2, 1, 3]

        # A 0 in the target shape keeps the corresponding input dimension.
        per_head = layers.reshape(
            x=x, shape=[0, 0, self.num_heads, self.head_dim])
        return layers.transpose(x=per_head, perm=axis_order)

    def _merge_heads(self, x):
        """
        Inverse of the non-key `_split_heads`: swap axes 1 and 2, then fold
        the trailing dimensions back into a single `hidden_dim` axis.
        """
        heads_last = layers.transpose(x=x, perm=[0, 2, 1, 3])
        return layers.reshape(x=heads_last, shape=[0, 0, self.hidden_dim])

    def _attn(self, query, key, value, mask):
        """
        Scaled dot-product attention over already-split heads.

        `key` is expected in the transposed layout produced by
        `_split_heads(..., is_key=True)`. When `mask` is given, positions
        where it equals 1 get a score of -1e10 before the softmax and their
        weights are zeroed again after dropout.
        """
        has_mask = mask is not None

        # shape: [batch_size, num_head, seq_len, seq_len]
        scores = layers.matmul(x=query, y=key, alpha=self.scale)

        if has_mask:
            # Insert a head axis and repeat the mask once per head.
            # NOTE(review): presumably mask is [batch_size, seq_len, seq_len]
            # -- confirm against callers.
            mask = layers.expand(
                F.unsqueeze(mask, [1]), [1, self.num_heads, 1, 1])
            mask.stop_gradient = True
            scores = (1 - mask) * scores + layers.scale(mask, scale=-1e10)

        weights = F.dropout(layers.softmax(scores, axis=-1), self.dropout)

        if has_mask:
            weights = (1 - mask) * weights

        return layers.matmul(x=weights, y=value)

    def forward(self, inp, mask=None, cache=None):
        """
        Forward process of self attention.

        Args:
            inp: Input variable fed to the fused query/key/value projection.
            mask: Optional mask; entries equal to 1 are blocked from
                attention.
            cache: Optional dict. If it already holds "key" and "value", the
                new key/value are appended to them before attending. In
                either case the dict is updated in place with the key/value
                that were used.

        Returns:
            The attention output after the output projection.
        """
        # shape: [batch_size, seq_len, 3 * hidden_dim]
        fused = self.linear_qkv(inp)
        q, k, v = layers.split(fused, num_or_sections=3, dim=2)

        # shape: [batch_size, num_head, seq_len, head_dim]
        q = self._split_heads(q)
        # shape: [batch_size, num_head, head_dim, seq_len]
        k = self._split_heads(k, is_key=True)
        # shape: [batch_size, num_head, seq_len, head_dim]
        v = self._split_heads(v)

        if cache is not None:
            has_history = "key" in cache and "value" in cache
            if has_history:
                # The key is stored transposed, so its sequence axis is 3,
                # while the value's sequence axis is 2.
                k = layers.concat([cache["key"], k], axis=3)
                v = layers.concat([cache["value"], v], axis=2)
            cache["key"] = k
            cache["value"] = v

        attended = self._attn(q, k, v, mask)
        merged = self._merge_heads(attended)
        return self.linear_out(merged)


def main():
    """
    Smoke test: run the attention layer once on random input and print it.

    Builds a `MultiheadAttention` with hidden size 10, 2 heads and dropout
    0.5 on CPU in dygraph mode, feeds it a random float32 batch of shape
    [2, 3, 10] and prints the output.
    """
    import numpy as np

    place = fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        model = MultiheadAttention("MultiheadAttention", 10, 2, 0.5)
        inp = np.random.rand(2, 3, 10).astype("float32")
        inp = fluid.dygraph.to_variable(inp)
        # forward() takes (inp, mask=None, cache=None). Passing the input
        # three times would bind it to `mask` (wrong shape for the scores)
        # and to `cache` (which must be a dict), so only the input is given.
        out = model(inp)
        print(out)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()