未验证 提交 6b0cc2b1 编写于 作者: L Liu-xiandong 提交者: GitHub

modify sparse_attention docs, test=document_fix (#36554)

* modify sparse_attention docs, test=develop

* add warning

* add warning ,test=document_fix
上级 31cd9145
...@@ -30,7 +30,7 @@ def sparse_attention(query, ...@@ -30,7 +30,7 @@ def sparse_attention(query,
This operator sparsify the Attention matrix in Transformer module This operator sparsify the Attention matrix in Transformer module
to achieve the effect of reducing memory consumption and computation. to achieve the effect of reducing memory consumption and computation.
The sparse layout is expressed in CSR format and contains two parameters, The sparse layout is expressed in CSR format and contains two parameters,
``offset`` and ``columns``. ``offset`` and ``columns``. The equation is:
.. math:: .. math::
...@@ -40,40 +40,42 @@ def sparse_attention(query, ...@@ -40,40 +40,42 @@ def sparse_attention(query,
The dimensions of the three parameters are the same. The dimensions of the three parameters are the same.
``d`` represents the size of the last dimension of the three parameters. ``d`` represents the size of the last dimension of the three parameters.
Parameters: Warning:
This API is only used in ``CUDA 11.3`` and above versions.
Args:
query(Tensor): The query tensor in the Attention module. query(Tensor): The query tensor in the Attention module.
It's a 4-D tensor with a shape of 4-D tensor with shape:
:math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. [batch_size, num_heads, seq_len, head_dim].
The dtype can be ``float32`` and ``float64``. The dtype can be float32 and float64.
key(Tensor): The key tensor in the Attention module. key(Tensor): The key tensor in the Attention module.
It's a 4-D tensor with a shape of 4-D tensor with shape:
:math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. [batch_size, num_heads, seq_len, head_dim].
The dtype can be ``float32`` and ``float64``. The dtype can be float32 and float64.
value(Tensor): The value tensor in the Attention module. value(Tensor): The value tensor in the Attention module.
It's a 4-D tensor with a shape of 4-D tensor with shape:
:math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. [batch_size, num_heads, seq_len, head_dim].
The dtype can be ``float32`` and ``float64``. The dtype can be float32 and float64.
sparse_csr_offset(Tensor): The sparsity feature in the Attention module sparse_csr_offset(Tensor): The sparsity feature in the Attention module
is expressed in the CSR format, and the offset represents is expressed in the CSR format, and the offset represents
the number of non-zero elements in each row of the matrix. the number of non-zero elements in each row of the matrix.
It's a 3-D tensor with a shape of 3-D tensor with shape:
:math:`[batch\_size, num\_heads, seq\_len + 1]`. [batch_size, num_heads, seq_len + 1].
The dtype should be ``int32``. The dtype should be int32.
sparse_csr_columns(Tensor): The sparsity feature in the Attention module sparse_csr_columns(Tensor): The sparsity feature in the Attention module
is expressed in the CSR format, and the columns represent is expressed in the CSR format, and the columns represent
the column index values of non-zero elements in the matrix. the column index values of non-zero elements in the matrix.
It's a 3-D tensor with a shape of 3-D tensor with shape:
:math:`[batch\_size, num\_heads, sparse\_nnz]`. [batch_size, num_heads, sparse_nnz].
The dtype should be ``int32``. The dtype should be int32.
name(str, optional): The default value is None. Normally there is no need for user name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to to set this property. For more information, please refer to
:ref:`api_guide_Name`. :ref:`api_guide_Name`.
Returns: Returns:
A Tensor which refers to the result in the Attention module. 4-D tensor with shape:
It's a 4-D tensor with a shape of [batch_size, num_heads, seq_len, head_dim].
:math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. The dtype can be float32 or float64.
The dtype can be ``float32`` and ``float64``.
Examples: Examples:
.. code-block:: python .. code-block:: python
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册