# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)


class MultiheadAttention(nn.Layer):
    r"""Allows the model to jointly attend to information from different
    representation subspaces. See reference: Attention Is All You Need.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O

        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: number of parallel attention layers, or heads
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
        self._reset_parameters()
        # 1x1 convolutions act as the per-branch input projections for q, k, v.
        self.conv1 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv2 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv3 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))

    def _reset_parameters(self):
        xavier_uniform_(self.out_proj.weight)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: [batch size, sequence length]; non-zero entries
                mark padded positions that must not be attended to
            incremental_state: if provided, previous time steps are cached
            attn_mask: additive mask applied to the attention scores,
                e.g. to prevent attending to future positions
        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
        """
        q_shape = paddle.shape(query)
        src_shape = paddle.shape(key)
        q = self._in_proj_q(query)
        k = self._in_proj_k(key)
        v = self._in_proj_v(value)
        q *= self.scaling
        # Split the embedding dimension into heads and move to
        # [batch, num_heads, length, head_dim] layout.
        q = paddle.transpose(
            paddle.reshape(
                q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        k = paddle.transpose(
            paddle.reshape(
                k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        v = paddle.transpose(
            paddle.reshape(
                v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        if key_padding_mask is not None:
            assert key_padding_mask.shape[0] == q_shape[1]
            assert key_padding_mask.shape[1] == src_shape[0]
        attn_output_weights = paddle.matmul(q,
                                            paddle.transpose(k, [0, 1, 3, 2]))
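        # attn_mask is added to the raw scores as-is; key_padding_mask is
        # converted into -inf entries so masked positions receive zero
        # attention weight after the softmax.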
        if attn_mask is not None:
            attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
            attn_output_weights += attn_mask
        if key_padding_mask is not None:
            attn_output_weights = paddle.reshape(
                attn_output_weights,
                [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
            # Broadcast the padding mask to [batch, 1, 1, sequence length].
            mask = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
            mask = paddle.cast(mask, 'float32')
            neg_inf = paddle.full(
                shape=paddle.shape(mask),
                dtype='float32',
                fill_value=float('-inf'))
            # Keep zeros where the mask is 0, use -inf where it is non-zero.
            attn_output_weights += paddle.where(mask == 0., mask, neg_inf)
        attn_output_weights = F.softmax(
            attn_output_weights.astype('float32'),
            axis=-1,
            dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training)
        attn_output = paddle.matmul(attn_output_weights, v)
        # Merge the heads back and return to [target length, batch, embed dim].
        attn_output = paddle.reshape(
            paddle.transpose(attn_output, [2, 0, 1, 3]),
            [q_shape[0], q_shape[1], self.embed_dim])
        attn_output = self.out_proj(attn_output)

        return attn_output

    def _in_proj_q(self, query):
        query = paddle.transpose(query, [1, 2, 0])
        query = paddle.unsqueeze(query, axis=2)
        res = self.conv1(query)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res

    def _in_proj_k(self, key):
        key = paddle.transpose(key, [1, 2, 0])
        key = paddle.unsqueeze(key, axis=2)
        res = self.conv2(key)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res

    def _in_proj_v(self, value):
        value = paddle.transpose(value, [1, 2, 0])
        value = paddle.unsqueeze(value, axis=2)
        res = self.conv3(value)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res
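

# Minimal usage sketch (illustration only, not part of the original module).
# It assumes a working paddle installation; the shapes below (target length 5,
# source length 7, batch 2, embed_dim 16, 4 heads) are arbitrary and only show
# the [length, batch size, embed dim] layout this layer expects.
if __name__ == "__main__":
    tgt_len, src_len, batch, embed_dim, num_heads = 5, 7, 2, 16, 4
    attn = MultiheadAttention(embed_dim, num_heads, dropout=0.1)
    attn.eval()  # disable dropout so the demo is deterministic

    query = paddle.randn([tgt_len, batch, embed_dim])
    key = paddle.randn([src_len, batch, embed_dim])
    value = paddle.randn([src_len, batch, embed_dim])

    # Mark the last two source positions of every batch element as padding.
    key_padding_mask = paddle.concat(
        [paddle.zeros([batch, src_len - 2]), paddle.ones([batch, 2])], axis=1)

    out = attn(query, key, value, key_padding_mask=key_padding_mask)
    print(out.shape)  # [5, 2, 16] -> [target length, batch size, embed dim]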