#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import paddle
from ...fluid.framework import default_main_program
from paddle.fluid.layer_helper import LayerHelper
from paddle import _C_ops, _legacy_C_ops
from paddle import in_dynamic_mode


def sparse_attention(
    query,
    key,
    value,
    sparse_csr_offset,
    sparse_csr_columns,
    key_padding_mask=None,
    attn_mask=None,
    name=None,
):
    r"""
    This operator sparsifies the Attention matrix in the Transformer module
    to reduce memory consumption and computation.
    The sparse layout is expressed in CSR format and contains two parameters,
    ``offset`` and ``columns``. The equation is:

    .. math::

        result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V

    where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
    The dimensions of the three parameters are the same.
    ``d`` represents the size of the last dimension of the three parameters.

    Warning:
        This API is only supported with ``CUDA 11.3`` and later versions.

    Args:
        query(Tensor): The query tensor in the Attention module.
                        4-D tensor with shape:
                        [batch_size, num_heads, seq_len, head_dim].
                        The dtype can be float32 or float64.
        key(Tensor): The key tensor in the Attention module.
                        4-D tensor with shape:
                        [batch_size, num_heads, seq_len, head_dim].
                        The dtype can be float32 or float64.
        value(Tensor): The value tensor in the Attention module.
                        4-D tensor with shape:
                        [batch_size, num_heads, seq_len, head_dim].
                        The dtype can be float32 or float64.
        sparse_csr_offset(Tensor): The sparsity pattern of the Attention module
                        is expressed in CSR format. The offset is the CSR row
                        offset (a prefix sum): ``offset[i + 1] - offset[i]`` is
                        the number of non-zero elements in row ``i`` of each
                        attention matrix.
                        3-D tensor with shape:
                        [batch_size, num_heads, seq_len + 1].
                        The dtype should be int32.
        sparse_csr_columns(Tensor): The sparsity pattern of the Attention module
                        is expressed in CSR format. The columns are the column
                        indices of the non-zero elements in each attention matrix.
                        3-D tensor with shape:
                        [batch_size, num_heads, sparse_nnz].
                        The dtype should be int32.
        key_padding_mask(Tensor, optional): The key padding mask tensor in the Attention module.
                        2-D tensor with shape: [batch_size, seq_len].
                        The dtype can be float32 or float64.
                        A value of 0 means that the position is masked.
        attn_mask(Tensor, optional): The attention mask tensor in the Attention module.
                        2-D tensor with shape: [seq_len, seq_len].
                        The dtype can be float32 or float64.
                        A value of 0 means that the position is masked.
        name(str, optional): The default value is None. Normally there is no need for users
                        to set this property. For more information, please refer to
                        :ref:`api_guide_Name`.

    Returns:
        4-D tensor with shape:
        [batch_size, num_heads, seq_len, head_dim].
        The dtype can be float32 or float64.

    Examples:
        .. code-block:: python

            # required: skiptest
            import paddle
            import numpy as np

            query_data = np.array([[[[0, 1], [2, 3],
                            [0, 1], [2, 3]]]]).astype("float32")
            key_data = np.array([[[[0, 1], [2, 3],
                            [0, 1], [2, 3]]]]).astype("float32")
            value_data = np.array([[[[0, 1], [2, 3],
                            [0, 1], [2, 3]]]]).astype("float32")
            sparse_csr_offset_data = np.array([[[0, 2,
                            4, 6, 8]]]).astype("int32")
            sparse_csr_columns_data = np.array([[[0, 1,
                            0, 1, 2, 3, 2, 3]]]).astype("int32")
            key_padding_mask_data = np.array([[1, 1, 1, 0]]).astype("float32")
            attention_mask_data = np.array([[1, 0, 1, 1],
                            [1, 1, 1, 1],
                            [1, 1, 1, 1],
                            [1, 1, 1, 1]]).astype("float32")
            print(query_data.shape)
            # (1, 1, 4, 2)
            print(sparse_csr_offset_data.shape)
            # (1, 1, 5)
            print(sparse_csr_columns_data.shape)
            # (1, 1, 8)
            paddle.disable_static()

            query = paddle.to_tensor(query_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            key = paddle.to_tensor(key_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            value = paddle.to_tensor(value_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            key_padding_mask = paddle.to_tensor(key_padding_mask_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))
            attention_mask = paddle.to_tensor(attention_mask_data, stop_gradient=False,
                            place=paddle.CUDAPlace(0))

            output_mask = paddle.nn.functional.sparse_attention(query, key,
                            value, offset, columns,
                            key_padding_mask=key_padding_mask, attn_mask=attention_mask)
            print(output_mask)
            # [[[[0.        , 1.        ],
            #    [1.99830270, 2.99830270],
            #    [0.        , 1.        ],
            #    [0.        , 1.        ]]]]

            output = paddle.nn.functional.sparse_attention(query, key,
                            value, offset, columns)
            print(output)
            # [[[[1.60885942, 2.60885954],
            #    [1.99830270, 2.99830270],
            #    [1.60885942, 2.60885954],
            #    [1.99830270, 2.99830270]]]]
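
            # Illustrative sketch: how the CSR ``offset``/``columns`` pair above
            # relates to an equivalent dense 0/1 layout. ``dense_layout`` and
            # ``offset_check`` are local helper names, not part of the API.
            dense_layout = np.array([[1, 1, 0, 0],
                            [1, 1, 0, 0],
                            [0, 0, 1, 1],
                            [0, 0, 1, 1]])
            rows, cols = np.nonzero(dense_layout)
            # CSR row offset: cumulative count of non-zeros per row, starting at 0.
            offset_check = np.concatenate(([0], np.cumsum(np.bincount(rows, minlength=4))))
            print(offset_check)   # [0 2 4 6 8]
            print(cols)           # [0 1 0 1 2 3 2 3]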
    """
    if in_dynamic_mode():
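        # Dynamic graph path: call the C++ sparse_attention op directly. It
        # returns the attention result plus two intermediate tensors (the SDD
        # product and the softmax); only the attention result is returned here.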
        (
            result_attention,
            result_sdd,
            result_softmax,
        ) = _legacy_C_ops.sparse_attention(
            query,
            key,
            value,
            sparse_csr_offset,
            sparse_csr_columns,
            key_padding_mask,
            attn_mask,
        )
        return result_attention

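    # Static graph path: create the output variables and append a
    # sparse_attention op to the current program through LayerHelper.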
    helper = LayerHelper('sparse_attention', **locals())
    dtype = helper.input_dtype(input_param_name='Q')
    out = helper.create_variable_for_type_inference(dtype)
    result_sdd = helper.create_variable_for_type_inference(dtype)
    result_softmax = helper.create_variable_for_type_inference(dtype)
    inputs = {
        'Q': query,
        'K': key,
        'V': value,
        'Offset': sparse_csr_offset,
        'Columns': sparse_csr_columns,
        'KeyPaddingMask': key_padding_mask,
        'AttnMask': attn_mask,
    }
    outputs = {
        'Out': out,
        'SparseDotSdd': result_sdd,
        'Softmax': result_softmax,
    }
    helper.append_op(type='sparse_attention', inputs=inputs, outputs=outputs)
    return out