x (Tensor): The input tensor of fused_multi_head_attention. The shape is
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
`[batch\_size, sequence\_len, embed\_dim]`.
`[batch\_size, sequence\_len, embed\_dim]`.
qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
qkv_weight (Tensor): The qkv weight tensor. If `transpose_qkv_wb` is False, the shape is `[3, num_head, dim_head, dim_embed]`. Otherwise, the shape is `[dim_embed, 3 * dim_embed]`.
linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
ln_bias (Tensor, optional): The bias tensor of layernorm. Default None.
ln_bias (Tensor, optional): The bias tensor of layernorm. Default None.
pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm
pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm
to avoid dividing by zero. Default is 1e-5.
to avoid dividing by zero. Default is 1e-5.
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
qkv_bias (Tensor, optional): The bias of qkv computation. If `transpose_qkv_wb` is False, the shape is `[3, num_head, dim_head]`. Otherwise, the shape is `[3 * dim_embed]`.
Default None.
Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
cache_kv (Tensor, optional): For generation model, cache structure. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
cache_kv (Tensor, optional): For generation model, cache structure. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
add_residual (bool, optional): Whether add residual at the end. Default is True.
add_residual (bool, optional): Whether add residual at the end. Default is True.
num_heads (int, optional): If enable transpose_qkv_wb, should provide the num_heads. Default is -1, means not transpose qkv wb.
num_heads (int, optional): If enable transpose_qkv_wb, should provide the num_heads. Default is -1, means not transpose qkv wb.
transpose_qkv_wb (bool, optional): Whether transpose the qkv_weight and qkv_bias in the op. Only support GPU for now. Default is false, means not transpose qkv wb.
transpose_qkv_wb (bool, optional): Whether transpose the qkv_weight and qkv_bias in the op. Only support GPU for now. Default is false, means not transpose qkv wb.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Returns:
Tensor: The output Tensor, the data type and shape is same as `x`.
Tensor: The output Tensor, the data type and shape is same as `x`.