# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import _C_ops, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper


def flash_attention(
    query,
    key,
    value,
    dropout=0.0,
    causal=False,
    return_softmax=False,
    name=None,
):
    r"""
    The equation is:

    .. math::

        result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V

    where ``Q``, ``K``, and ``V`` represent the three input parameters of
    the attention module. The dimensions of the three parameters are the
    same. ``d`` represents the size of the last dimension of the three
    parameters.

    Warning:
        This API only supports inputs with dtype float16 and bfloat16.

    Args:
        query(Tensor): The query tensor in the attention module.
                       4-D tensor with shape
                       [batch_size, seq_len, num_heads, head_dim].
                       The dtype can be float16 or bfloat16.
        key(Tensor): The key tensor in the attention module.
                     4-D tensor with shape
                     [batch_size, seq_len, num_heads, head_dim].
                     The dtype can be float16 or bfloat16.
        value(Tensor): The value tensor in the attention module.
                       4-D tensor with shape
                       [batch_size, seq_len, num_heads, head_dim].
                       The dtype can be float16 or bfloat16.
        dropout(float): The dropout ratio.
        causal(bool): Whether to enable causal mode.
        return_softmax(bool): Whether to return the softmax tensor.
        name(str, optional): The default value is None. Normally there is no
                             need for user to set this property. For more
                             information, please refer to
                             :ref:`api_guide_Name`.

    Returns:
        out(Tensor): The attention tensor.
                     4-D tensor with shape
                     [batch_size, seq_len, num_heads, head_dim].
                     The dtype can be float16 or bfloat16.
        softmax(Tensor): The softmax tensor. None if return_softmax is False.

    Examples:
        .. code-block:: python

            # required: skiptest
            import paddle

            q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)

            output = paddle.nn.functional.flash_attention(
                q, q, q, 0.9, False, False
            )
            print(output)
    """
    if in_dynamic_mode():
        (
            result_attention,
            result_softmax,
        ) = _C_ops.flash_attn(
            query,
            key,
            value,
            dropout,
            causal,
            return_softmax,
        )
        return result_attention, result_softmax

    helper = LayerHelper('flash_attn', **locals())
    # The LayerHelper kwargs come from locals(), so the input dtype must be
    # looked up under the parameter name 'query' (not the op input name 'q').
    dtype = helper.input_dtype(input_param_name='query')
    out = helper.create_variable_for_type_inference(dtype)
    softmax = helper.create_variable_for_type_inference(dtype)
    softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
    seed_offset = helper.create_variable_for_type_inference(paddle.int64)
    inputs = {
        'q': query,
        'k': key,
        'v': value,
    }
    outputs = {
        'out': out,
        'softmax': softmax,
        'softmax_lse': softmax_lse,
        'seed_offset': seed_offset,
    }
    helper.append_op(
        type='flash_attn',
        inputs=inputs,
        outputs=outputs,
        attrs={
            'dropout': dropout,
            'causal': causal,
            'return_softmax': return_softmax,
        },
    )
    return out, softmax
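
# A minimal reference sketch, not part of the public API: it evaluates the
# same softmax(Q * K^T / sqrt(d)) * V equation with plain paddle ops, assuming
# 4-D [batch_size, seq_len, num_heads, head_dim] inputs and ignoring dropout
# and causal masking. The helper name `_naive_attention_reference` is
# hypothetical; it is only meant for eyeballing flash_attention results on
# small shapes.
def _naive_attention_reference(query, key, value):
    # Move heads in front of the sequence axis: [b, s, h, d] -> [b, h, s, d].
    q = paddle.transpose(query, [0, 2, 1, 3])
    k = paddle.transpose(key, [0, 2, 1, 3])
    v = paddle.transpose(value, [0, 2, 1, 3])
    head_dim = q.shape[-1]
    # Attention scores: [b, h, s, s], scaled by 1/sqrt(d).
    scores = paddle.matmul(q, k, transpose_y=True) / (head_dim**0.5)
    probs = paddle.nn.functional.softmax(scores, axis=-1)
    # Weighted sum over values, then back to [b, s, h, d].
    return paddle.transpose(paddle.matmul(probs, v), [0, 2, 1, 3])
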

def flash_attn_unpadded(
    query,
    key,
    value,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    scale,
    dropout=0.0,
    causal=False,
    return_softmax=False,
    name=None,
):
    r"""
    The equation is:

    .. math::

        result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V

    where ``Q``, ``K``, and ``V`` represent the three input parameters of
    the attention module. The dimensions of the three parameters are the
    same. ``d`` represents the size of the last dimension of the three
    parameters.

    Warning:
        This API only supports inputs with dtype float16 and bfloat16.

    Args:
        query(Tensor): The query tensor in the attention module.
                       3-D tensor with shape
                       [total_seq_len, num_heads, head_dim].
                       The dtype can be float16 or bfloat16.
        key(Tensor): The key tensor in the attention module.
                     3-D tensor with shape
                     [total_seq_len, num_heads, head_dim].
                     The dtype can be float16 or bfloat16.
        value(Tensor): The value tensor in the attention module.
                       3-D tensor with shape
                       [total_seq_len, num_heads, head_dim].
                       The dtype can be float16 or bfloat16.
        cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences
                              in the batch, used to index query.
        cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences
                              in the batch, used to index key and value.
        max_seqlen_q(int): Maximum sequence length of query in the batch.
        max_seqlen_k(int): Maximum sequence length of key/value in the batch.
        scale(float): The scaling of QK^T before applying softmax.
        dropout(float): The dropout ratio.
        causal(bool): Whether to enable causal mode.
        return_softmax(bool): Whether to return the softmax tensor.
        name(str, optional): The default value is None. Normally there is no
                             need for user to set this property. For more
                             information, please refer to
                             :ref:`api_guide_Name`.

    Returns:
        out(Tensor): The attention tensor.
                     3-D tensor with shape
                     [total_seq_len, num_heads, head_dim].
                     The dtype can be float16 or bfloat16.
        softmax(Tensor): The softmax tensor. None if return_softmax is False.

    Examples:
        .. code-block:: python

            # required: skiptest
            import paddle

            q = paddle.rand((128, 2, 16), dtype=paddle.float16)
            cu_seqlens = paddle.to_tensor([0, 128], dtype=paddle.int32)
            scale = 1.0 / (16**0.5)

            output = paddle.nn.functional.flash_attn_unpadded(
                q, q, q, cu_seqlens, cu_seqlens, 128, 128, scale
            )
            print(output)
    """
    if in_dynamic_mode():
        (
            result_attention,
            result_softmax,
        ) = _C_ops.flash_attn_unpadded(
            query,
            key,
            value,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            scale,
            dropout,
            causal,
            return_softmax,
        )
        return result_attention, result_softmax

    helper = LayerHelper('flash_attn_unpadded', **locals())
    # As above, look up the dtype under the parameter name bound in locals().
    dtype = helper.input_dtype(input_param_name='query')
    out = helper.create_variable_for_type_inference(dtype)
    softmax = helper.create_variable_for_type_inference(dtype)
    softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
    seed_offset = helper.create_variable_for_type_inference(paddle.int64)
    inputs = {
        'q': query,
        'k': key,
        'v': value,
        'cu_seqlens_q': cu_seqlens_q,
        'cu_seqlens_k': cu_seqlens_k,
    }
    outputs = {
        'out': out,
        'softmax': softmax,
        'softmax_lse': softmax_lse,
        'seed_offset': seed_offset,
    }
    helper.append_op(
        type='flash_attn_unpadded',
        inputs=inputs,
        outputs=outputs,
        attrs={
            'max_seqlen_q': max_seqlen_q,
            'max_seqlen_k': max_seqlen_k,
            'scale': scale,
            'dropout': dropout,
            'causal': causal,
            'return_softmax': return_softmax,
        },
    )
    return out, softmax
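
# A minimal usage sketch for the varlen path above, assuming a GPU build of
# Paddle with the flash_attn kernels compiled in (the op is unavailable on
# CPU), hence the __main__ guard rather than running on import. Two sequences
# of lengths 60 and 68 are packed into one [total_seq_len, num_heads, head_dim]
# tensor; cu_seqlens holds the cumulative offsets [0, 60, 128] that delimit
# them. All names below are local to this sketch.
if __name__ == "__main__":
    total_seq_len, num_heads, head_dim = 128, 2, 16
    qkv = paddle.rand(
        (total_seq_len, num_heads, head_dim), dtype=paddle.float16
    )
    seq_lens = paddle.to_tensor([60, 68], dtype=paddle.int32)
    # Cumulative offsets delimiting each packed sequence: [0, 60, 128].
    cu_seqlens = paddle.concat(
        [paddle.zeros([1], dtype=paddle.int32), paddle.cumsum(seq_lens)]
    )
    max_len = int(seq_lens.max())
    out, _ = flash_attn_unpadded(
        qkv,
        qkv,
        qkv,
        cu_seqlens,
        cu_seqlens,
        max_len,
        max_len,
        scale=1.0 / head_dim**0.5,
    )
    print(out.shape)  # expected: [128, 2, 16]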