#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import paddle
from ...fluid.framework import in_dygraph_mode, default_main_program
from paddle.fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode
from paddle import _C_ops


def sparse_attention(query,
                     key,
                     value,
                     sparse_csr_offset,
                     sparse_csr_columns,
                     key_padding_mask=None,
                     attn_mask=None,
                     name=None):
    r"""
    This operator sparsifies the Attention matrix in the Transformer module
    to reduce memory consumption and computation.
    The sparse layout is expressed in CSR format and is described by two
    parameters, ``offset`` and ``columns``. The equation is:

    .. math::

        result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V

    where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
    The dimensions of the three parameters are the same.
    ``d`` represents the size of the last dimension of the three parameters.

    Warning:
        This API is only used in ``CUDA 11.3`` and above versions.

    Args:
        query(Tensor): The query tensor in the Attention module.
                        4-D tensor with shape:
                        [batch_size, num_heads, seq_len, head_dim].
                        The dtype can be float32 and float64.
        key(Tensor): The key tensor in the Attention module.
                        4-D tensor with shape:
                        [batch_size, num_heads, seq_len, head_dim].
                        The dtype can be float32 and float64.
        value(Tensor): The value tensor in the Attention module.
                        4-D tensor with shape:
                        [batch_size, num_heads, seq_len, head_dim].
                        The dtype can be float32 and float64.
        sparse_csr_offset(Tensor): The sparsity feature in the Attention module
                        is expressed in the CSR format. The offset holds the
                        cumulative number of non-zero elements up to each row
                        of the matrix, so ``offset[i + 1] - offset[i]`` is the
                        number of non-zero elements in row ``i``.
                        3-D tensor with shape:
                        [batch_size, num_heads, seq_len + 1].
                        The dtype should be int32.
        sparse_csr_columns(Tensor): The sparsity feature in the Attention module
                        is expressed in the CSR format, and the columns represent
                        the column index values of non-zero elements in the matrix.
                        3-D tensor with shape:
                        [batch_size, num_heads, sparse_nnz].
                        The dtype should be int32.
        key_padding_mask(Tensor, optional): The key padding mask tensor in the Attention module.
                        2-D tensor with shape: [batch_size, seq_len].
                        The dtype can be float32 and float64.
                        A value of 0 means that the position is masked.
                        The default value is None.
        attn_mask(Tensor, optional): The attention mask tensor in the Attention module.
                        2-D tensor with shape: [seq_len, seq_len].
                        The dtype can be float32 and float64.
                        A value of 0 means that the position is masked.
                        The default value is None.
        name(str, optional): The default value is None. Normally there is no need for user
                        to set this property. For more information, please refer to
                        :ref:`api_guide_Name`.

    Returns:
        4-D tensor with shape:
        [batch_size, num_heads, seq_len, head_dim].
        The dtype can be float32 or float64.

    Examples:
        .. code-block:: python

            # required: skiptest
            import paddle
            import numpy as np

            query_data = np.array([[[[0, 1,], [2, 3],
                    [ 0, 1], [2, 3]]]]).astype("float32")
            key_data = np.array([[[[0, 1,], [2, 3],
                            [ 0, 1], [2, 3]]]]).astype("float32")
            value_data = np.array([[[[0, 1,], [2, 3],
                            [ 0, 1], [2, 3]]]]).astype("float32")
            sparse_csr_offset_data = np.array([[[0, 2,
                            4, 6, 8]]]).astype("int32")
            sparse_csr_columns_data = np.array([[[0, 1,
                            0, 1, 2, 3, 2, 3]]]).astype("int32")
            key_padding_mask_data = np.array([[1,1,1,0]]).astype("float32")
            attention_mask_data = np.array([[1,0,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]]).astype("float32")
            print(query_data.shape)
            # (1, 1, 4, 2)
            print(sparse_csr_offset_data.shape)
            # (1, 1, 5)
            print(sparse_csr_columns_data.shape)
            # (1, 1, 8)
            paddle.disable_static()
            query = paddle.to_tensor(query_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            key = paddle.to_tensor(key_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            value = paddle.to_tensor(value_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            key_padding_mask = paddle.to_tensor(key_padding_mask_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            attention_mask = paddle.to_tensor(attention_mask_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            output_mask = paddle.nn.functional.sparse_attention(query, key,
                            value, offset, columns,
                            key_padding_mask=key_padding_mask, attn_mask=attention_mask)
            print(output_mask)
            # [[[[0.        , 1.        ],
            #    [1.99830270, 2.99830270],
            #    [0.        , 1.        ],
            #    [0.        , 1.        ]]]]
            output = paddle.nn.functional.sparse_attention(query, key,
                            value, offset, columns)
            print(output)
            # [[[[1.60885942, 2.60885954],
            #       [1.99830270, 2.99830270],
            #       [1.60885942, 2.60885954],
            #       [1.99830270, 2.99830270]]]]
    """
    if in_dygraph_mode():
        # The op yields three results (attention output plus two intermediate
        # tensors); only the attention output is handed back to the caller.
        result_attention, _, _ = _C_ops.sparse_attention(
            query, key, value, sparse_csr_offset, sparse_csr_columns,
            key_padding_mask, attn_mask)
        return result_attention

    # Static-graph path. ``locals()`` has to be captured before any other
    # local is bound so the helper receives exactly this function's arguments.
    helper = LayerHelper('sparse_attention', **locals())
    # The helper was built from ``locals()``, so its entries are keyed by the
    # Python argument names; 'Q' is only the operator-side input name and is
    # not one of those keys. Look the dtype up through 'query' instead.
    dtype = helper.input_dtype(input_param_name='query')

    # One variable per operator output; the two intermediates are required by
    # the op's output list but are not returned.
    out = helper.create_variable_for_type_inference(dtype)
    result_sdd = helper.create_variable_for_type_inference(dtype)
    result_softmax = helper.create_variable_for_type_inference(dtype)

    inputs = {
        'Q': query,
        'K': key,
        'V': value,
        'Offset': sparse_csr_offset,
        'Columns': sparse_csr_columns,
        # The two masks are optional and are forwarded as-is (possibly None).
        'KeyPaddingMask': key_padding_mask,
        'AttnMask': attn_mask,
    }
    outputs = {
        'Out': out,
        'SparseDotSdd': result_sdd,
        'Softmax': result_softmax
    }
    helper.append_op(type='sparse_attention', inputs=inputs, outputs=outputs)
    return out