@@ -19,6 +19,7 @@ from paddle.framework import LayerHelper, in_dynamic_mode
...
@@ -19,6 +19,7 @@ from paddle.framework import LayerHelper, in_dynamic_mode
defmasked_multihead_attention(
defmasked_multihead_attention(
x,
x,
cache_kv=None,
cache_kv=None,
bias=None,
src_mask=None,
src_mask=None,
cum_offsets=None,
cum_offsets=None,
sequence_lengths=None,
sequence_lengths=None,
...
@@ -30,6 +31,7 @@ def masked_multihead_attention(
...
@@ -30,6 +31,7 @@ def masked_multihead_attention(
seq_len=1,
seq_len=1,
rotary_emb_dims=0,
rotary_emb_dims=0,
use_neox_rotary_style=False,
use_neox_rotary_style=False,
compute_dtype='default',
out_scale=-1,
out_scale=-1,
quant_round_type=1,
quant_round_type=1,
quant_max_bound=127.0,
quant_max_bound=127.0,
...
@@ -43,6 +45,7 @@ def masked_multihead_attention(
...
@@ -43,6 +45,7 @@ def masked_multihead_attention(
Args:
Args:
x (Tensor): The input tensor could be 2-D tensor. Its shape is [batch_size, 3 * num_head * head_dim].
x (Tensor): The input tensor could be 2-D tensor. Its shape is [batch_size, 3 * num_head * head_dim].
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
bias (Tensor, optional): The bias tensor. Its shape is [3, num_head, head_dim].
src_mask (Tensor, optional): The src_mask tensor. Its shape is [batch_size, 1, 1, sequence_length].
src_mask (Tensor, optional): The src_mask tensor. Its shape is [batch_size, 1, 1, sequence_length].
sequence_lengths (Tensor, optional): The sequence_lengths tensor, used to index input. Its shape is [batch_size, 1].
sequence_lengths (Tensor, optional): The sequence_lengths tensor, used to index input. Its shape is [batch_size, 1].
rotary_tensor (Tensor, optional): The rotary_tensor tensor. The dtype must be float. Its shape is [batch_size, 1, 1, sequence_length, head_dim].
rotary_tensor (Tensor, optional): The rotary_tensor tensor. The dtype must be float. Its shape is [batch_size, 1, 1, sequence_length, head_dim].
...
@@ -53,6 +56,7 @@ def masked_multihead_attention(
...
@@ -53,6 +56,7 @@ def masked_multihead_attention(
seq_len (int, optional): The seq_len, used to get input length. Default 1.
seq_len (int, optional): The seq_len, used to get input length. Default 1.
rotary_emb_dims (int, optional): The rotary_emb_dims. Default 1.
rotary_emb_dims (int, optional): The rotary_emb_dims. Default 1.
use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False.
use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False.
compute_dtype (string): A compute dtype, used to represent the input data type.
out_scale (float, optional): The out_scale, used in quant.
out_scale (float, optional): The out_scale, used in quant.
quant_round_type (int, optional): The quant_round_type, used in quant. Default 1.
quant_round_type (int, optional): The quant_round_type, used in quant. Default 1.
quant_max_bound (float, optional): The quant_max_bound, used in quant. Default 127.0.
quant_max_bound (float, optional): The quant_max_bound, used in quant. Default 127.0.