未验证 提交 581d05bb 编写于 作者: T tianhaodongbd 提交者: GitHub

add sin and cos optional parameters to fused_rope op (#55415)

上级 18de0c94
......@@ -17,10 +17,10 @@
support_dygraph_mode : true
- backward_op : fused_rotary_position_embedding_grad
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
......
......@@ -109,11 +109,11 @@
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index
- op : fused_rotary_position_embedding
args : (Tensor q, Tensor k, Tensor v)
args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v, out_k, out_v
optional : k,v,sin,cos, out_k, out_v
kernel :
func : fused_rotary_position_embedding
data_type : q
......
......@@ -1217,7 +1217,9 @@ void IndexPutGradInferMeta(const MetaTensor& x,
}
}
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
void FusedRopeGradInferMeta(const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
......
......@@ -184,7 +184,9 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
MetaTensor* x_grad,
MetaTensor* y_grad);
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
void FusedRopeGradInferMeta(const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
......
......@@ -3617,6 +3617,8 @@ void FusedConvInferMeta(const MetaTensor& input,
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
const MetaTensor& sin,
const MetaTensor& cos,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v) {
......
......@@ -717,6 +717,8 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
const MetaTensor& sin,
const MetaTensor& cos,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v);
......
......@@ -18,68 +18,14 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
int batch_size,
int seq_len,
int num_heads,
int head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int stride = gridDim.x * blockDim.x * VecSize;
int size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;
for (; index < size; index += stride) {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indicses;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}
#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;
MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] = cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
result[ls_index] = cos_value[ls_index] * p1 - sin_value[pr_index] * p0;
store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}
template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& sin,
const paddle::optional<DenseTensor>& cos,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
......@@ -111,6 +57,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
phi::Array<const T*, 2> sin_cos_data;
ins_data[0] = dout_q.data<T>();
outs_data[0] = dq->data<T>();
......@@ -135,8 +82,20 @@ void FusedRopeGradKernel(const Context& dev_ctx,
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);
VectorizedFusedRopeGradKernel<T, MPType, vec_size>
bool flag_sin_cos = false;
if (sin.get_ptr() && cos.get_ptr()) {
sin_cos_data[0] = sin->data<T>();
sin_cos_data[1] = cos->data<T>();
flag_sin_cos = true;
}
int sign = -1;
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
......
......@@ -18,76 +18,17 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
int batch_size,
int seq_len,
int num_heads,
int head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int stride = gridDim.x * blockDim.x * VecSize;
int size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;
for (; index < size; index += stride) {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indicses;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}
#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;
MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] = cos_value[pr_index] * p0;
result[pr_index] -= sin_value[pr_index] * p1;
result[ls_index] = sin_value[ls_index] * p0;
result[ls_index] += cos_value[ls_index] * p1;
store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}
template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
const DenseTensor& q,
const paddle::optional<DenseTensor>& k,
const paddle::optional<DenseTensor>& v,
const paddle::optional<DenseTensor>& sin,
const paddle::optional<DenseTensor>& cos,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v) {
......@@ -116,6 +57,7 @@ void FusedRopeKernel(const Context& dev_ctx,
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
phi::Array<const T*, 2> sin_cos_data;
ins_data[0] = q.data<T>();
outs_data[0] = out_q->data<T>();
......@@ -140,8 +82,45 @@ void FusedRopeKernel(const Context& dev_ctx,
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);
bool flag_sin_cos = false;
if (sin.get_ptr() && cos.get_ptr()) {
PADDLE_ENFORCE_EQ(sin.get_ptr()->dims(),
cos.get_ptr()->dims(),
phi::errors::InvalidArgument(
"The dims of sin and cos must be the same."));
auto sin_dims = sin.get_ptr()->dims();
int dims_size = sin_dims.size();
PADDLE_ENFORCE_NE((dims_size == 2 || dims_size == 4),
false,
phi::errors::InvalidArgument(
"The dims of sin and cos must be 2 or 4."));
if (dims_size == 4) {
PADDLE_ENFORCE_NE(
(sin_dims[0] == 1 && sin_dims[1] == 1),
false,
phi::errors::InvalidArgument(
"The batch_size and num_heads of sin and cos must be 1."));
}
PADDLE_ENFORCE_NE(
(sin_dims[dims_size - 1] == head_dim &&
sin_dims[dims_size - 2] == seq_len),
false,
phi::errors::InvalidArgument("The seq_len and head_dim of sin and cos "
"must be the same as those of q."));
sin_cos_data[0] = sin->data<T>();
sin_cos_data[1] = cos->data<T>();
flag_sin_cos = true;
}
int sign = 1;
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
phi::Array<const T*, 2> sin_cos_data,
bool flag_sin_cos,
int sign,
int batch_size,
int seq_len,
int num_heads,
int head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int stride = gridDim.x * blockDim.x * VecSize;
int size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;
for (; index < size; index += stride) {
if (flag_sin_cos) {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
int pos_head = index_wc % head_dim;
int index_sc = pos_seq * head_dim + pos_head;
const T* sin_input = sin_cos_data[0] + index_sc;
const T* cos_input = sin_cos_data[1] + index_sc;
sin_value[nx] = static_cast<MPType>(sin_input[0]);
cos_value[nx] = static_cast<MPType>(cos_input[0]);
}
} else {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indicses;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}
}
#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;
MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] =
cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1;
result[ls_index] =
cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0;
store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}
} // namespace fusion
} // namespace phi
......@@ -17,16 +17,16 @@ from paddle import _C_ops
from paddle.framework import in_dynamic_mode
def fused_rotary_position_embedding(q, k, v):
def fused_rotary_position_embedding(q, k, v, sin=None, cos=None):
r"""
Fused rotary position embedding.
Args:
q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
k (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
v (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
k (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
Returns:
out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` .
......@@ -44,6 +44,12 @@ def fused_rotary_position_embedding(q, k, v):
k = paddle.randn([1, 1, 4, 10], dtype='float16')
v = paddle.randn([1, 1, 4, 10], dtype='float16')
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
x = paddle.randn([1, 1, 1, 10], dtype='float16')
y = paddle.randn([1, 1, 1, 10], dtype='float16')
sin = paddle.sin(x)
cos = paddle.cos(y)
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos)
"""
if in_dynamic_mode():
return _C_ops.fused_rotary_position_embedding(q, k, v)
return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos)
......@@ -41,38 +41,44 @@ def mult_qkv(value, cos_tensor, sin_tensor):
return query
def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
q, k, v = deal_qkv(init_q, init_k, init_v)
pos_seq = paddle.arange(0, q.shape[2], 1, dtype="float32")
indices = paddle.arange(0, q.shape[3], 2, dtype="float32")
def get_sin_cos_tensor(seq_len, head_dim, sign):
pos_seq = paddle.arange(0, seq_len, 1, dtype="float32")
indices = paddle.arange(0, head_dim, 2, dtype="float32")
indices = 1 / 10000 ** (indices / q.shape[3])
indices = 1 / 10000 ** (indices / head_dim)
sinusoid_inp = pos_seq.unsqueeze(1) * indices.unsqueeze(0)
sin_sin = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
cos_cos = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
sin_sin = np.empty((seq_len * head_dim), dtype=np.float32)
cos_cos = np.empty((seq_len * head_dim), dtype=np.float32)
numpy_array = sinusoid_inp.numpy()
iter_array = np.nditer(numpy_array)
i = 0
for value in iter_array:
sin_sin[i * 2] = -1 * np.sin(value)
sin_sin[i * 2] = sign * np.sin(value)
cos_cos[i * 2 + 0] = np.cos(value)
sin_sin[i * 2 + 1] = np.sin(value)
cos_cos[i * 2 + 1] = np.cos(value)
i += 1
sin_tensor = paddle.reshape(
paddle.to_tensor(sin_sin, place=paddle.CPUPlace()),
[1, 1, q.shape[2], q.shape[3]],
tensor_sin = paddle.reshape(
paddle.to_tensor(sin_sin),
[1, 1, seq_len, head_dim],
)
cos_tensor = paddle.reshape(
paddle.to_tensor(cos_cos, place=paddle.CPUPlace()),
[1, 1, q.shape[2], q.shape[3]],
tensor_cos = paddle.reshape(
paddle.to_tensor(cos_cos),
[1, 1, seq_len, head_dim],
)
return tensor_sin, tensor_cos
def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
q, k, v = deal_qkv(init_q, init_k, init_v)
sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1)
query = mult_qkv(q, cos_tensor, sin_tensor)
value = mult_qkv(v, cos_tensor, sin_tensor)
key = mult_qkv(k, cos_tensor, sin_tensor)
......@@ -98,7 +104,7 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
tmp.stop_gradient = False
return tmp
def get_forward_backward(self, rope_function, seed):
def get_forward_backward(self, rope_function, seed, flag=0):
paddle.disable_static()
paddle.seed(seed)
fw = []
......@@ -106,6 +112,14 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
tensor_q = self.get_paddle_tensor()
tensor_k = self.get_paddle_tensor()
tensor_v = self.get_paddle_tensor()
if flag:
tensor_sin, tensor_cos = get_sin_cos_tensor(
tensor_q.shape[1], tensor_q.shape[3], 1
)
out_q, out_k, out_v = rope_function(
tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos
)
else:
out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v)
fw.append(out_q)
......@@ -139,6 +153,21 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)
def test_fused_dropout_add_sin_cos(self):
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding, seed=self.seed
)
f_fw, f_bw = self.get_forward_backward(
fused_rotary_position_embedding, seed=self.seed, flag=1
)
for i in range(len(p_fw)):
np.testing.assert_allclose(
p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
)
np.testing.assert_allclose(
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册