From 6b0cc2b187ef41f87fe6f13760e2e88c16d46cb0 Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Mon, 15 Nov 2021 10:03:52 +0800 Subject: [PATCH] modify sparse_attention docs, test=document_fix (#36554) * modify sparse_attention docs, test=develop * add warning * add warning ,test=document_fix --- .../paddle/nn/functional/sparse_attention.py | 44 ++++++++++--------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/python/paddle/nn/functional/sparse_attention.py b/python/paddle/nn/functional/sparse_attention.py index f57669f1145..b98e8142f45 100644 --- a/python/paddle/nn/functional/sparse_attention.py +++ b/python/paddle/nn/functional/sparse_attention.py @@ -30,7 +30,7 @@ def sparse_attention(query, This operator sparsify the Attention matrix in Transformer module to achieve the effect of reducing memory consumption and computation. The sparse layout is expressed in CSR format and contains two parameters, - ``offset`` and ``columns``. + ``offset`` and ``columns``. The equation is: .. math:: @@ -40,40 +40,42 @@ def sparse_attention(query, The dimensions of the three parameters are the same. ``d`` represents the size of the last dimension of the three parameters. - Parameters: + Warning: + This API is only used in ``CUDA 11.3`` and above versions. + + Args: query(Tensor): The query tensor in the Attention module. - It's a 4-D tensor with a shape of - :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. - The dtype can be ``float32`` and ``float64``. + 4-D tensor with shape: + [batch_size, num_heads, seq_len, head_dim]. + The dtype can be float32 and float64. key(Tensor): The key tensor in the Attention module. - It's a 4-D tensor with a shape of - :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. - The dtype can be ``float32`` and ``float64``. + 4-D tensor with shape: + [batch_size, num_heads, seq_len, head_dim]. + The dtype can be float32 and float64. value(Tensor): The value tensor in the Attention module. - It's a 4-D tensor with a shape of - :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. - The dtype can be ``float32`` and ``float64``. + 4-D tensor with shape: + [batch_size, num_heads, seq_len, head_dim]. + The dtype can be float32 and float64. sparse_csr_offset(Tensor): The sparsity feature in the Attention module is expressed in the CSR format, and the offset represents the number of non-zero elements in each row of the matrix. - It's a 3-D tensor with a shape of - :math:`[batch\_size, num\_heads, seq\_len + 1]`. - The dtype should be ``int32``. + 3-D tensor with shape: + [batch_size, num_heads, seq_len + 1]. + The dtype should be int32. sparse_csr_columns(Tensor): The sparsity feature in the Attention module is expressed in the CSR format, and the columns represent the column index values of non-zero elements in the matrix. - It's a 3-D tensor with a shape of - :math:`[batch\_size, num\_heads, sparse\_nnz]`. - The dtype should be ``int32``. + 3-D tensor with shape: + [batch_size, num_heads, sparse_nnz]. + The dtype should be int32. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - A Tensor which refers to the result in the Attention module. - It's a 4-D tensor with a shape of - :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. - The dtype can be ``float32`` and ``float64``. + 4-D tensor with shape: + [batch_size, num_heads, seq_len, head_dim]. + The dtype can be float32 or float64. Examples: .. code-block:: python -- GitLab