q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
k (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
k (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
v (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
v (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
Returns:
Returns:
out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` .
out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` .