Unverified · Commit c089a2af · authored by T tianhaodongbd, committed by GitHub

Add rotate_half implementation for fused_rope (#56401)

* add rotate_half in fused_rope

* add position_ids in fused_rope

* modified examples about fused_rope

* add set_device in examples
Parent be9cb946
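For orientation, the two rotation conventions selected by the new `use_neox_rotary_style` flag can be sketched with plain paddle ops. This is a hedged reading aid, not code from this commit; the helper names (`rotate_every_two`, `rotate_half`, `apply_rope`) are illustrative.

```python
import paddle

def rotate_every_two(x):
    # use_neox_rotary_style=True: pair adjacent channels (x0, x1), (x2, x3), ...
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    return paddle.reshape(paddle.stack([-x_odd, x_even], axis=-1), x.shape)

def rotate_half(x):
    # use_neox_rotary_style=False: pair channel i with channel i + head_dim // 2
    half = x.shape[-1] // 2
    return paddle.concat([-x[..., half:], x[..., :half]], axis=-1)

def apply_rope(x, sin, cos, rotate_fn):
    # both styles share the same form: out = x * cos + rotate(x) * sin,
    # with sin/cos laid out to match the chosen pairing
    return x * cos + rotate_fn(x) * sin
```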
@@ -17,10 +17,10 @@
   support_dygraph_mode : true
 
 - backward_op : fused_rotary_position_embedding_grad
-  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
-  args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
+  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
+  args : (Tensor sin, Tensor cos, Tensor position_ids, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad, bool use_neox_rotary_style)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
-  optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad
+  optional : sin, cos, position_ids, out_k_grad, out_v_grad, k_grad, v_grad
   infer_meta :
     func : FusedRopeGradInferMeta
   kernel :
......
@@ -149,11 +149,11 @@
   optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index
 
 - op : fused_rotary_position_embedding
-  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
+  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true)
   output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
   infer_meta :
     func : FusedRopeInferMeta
-  optional : k,v,sin,cos, out_k, out_v
+  optional : k, v, sin, cos, position_ids, out_k, out_v
   kernel :
     func : fused_rotary_position_embedding
     data_type : q
......
@@ -1219,9 +1219,11 @@ void IndexPutGradInferMeta(const MetaTensor& x,
 
 void FusedRopeGradInferMeta(const MetaTensor& sin,
                             const MetaTensor& cos,
+                            const MetaTensor& position_ids,
                             const MetaTensor& dout_q,
                             const MetaTensor& dout_k,
                             const MetaTensor& dout_v,
+                            bool use_neox_rotary_style,
                             MetaTensor* dq,
                             MetaTensor* dk,
                             MetaTensor* dv) {
......
@@ -186,9 +186,11 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
 
 void FusedRopeGradInferMeta(const MetaTensor& sin,
                             const MetaTensor& cos,
+                            const MetaTensor& position_ids,
                             const MetaTensor& dout_q,
                             const MetaTensor& dout_k,
                             const MetaTensor& dout_v,
+                            bool use_neox_rotary_style,
                             MetaTensor* dq,
                             MetaTensor* dk,
                             MetaTensor* dv);
......
@@ -4041,6 +4041,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         const MetaTensor& v,
                         const MetaTensor& sin,
                         const MetaTensor& cos,
+                        const MetaTensor& position_ids,
+                        bool use_neox_rotary_style,
                         MetaTensor* out_q,
                         MetaTensor* out_k,
                         MetaTensor* out_v) {
......
@@ -807,6 +807,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         const MetaTensor& v,
                         const MetaTensor& sin,
                         const MetaTensor& cos,
+                        const MetaTensor& position_ids,
+                        bool use_neox_rotary_style,
                         MetaTensor* out_q,
                         MetaTensor* out_k,
                         MetaTensor* out_v);
......
@@ -27,9 +27,11 @@ template <typename T, typename Context>
 void FusedRopeGradKernel(const Context& dev_ctx,
                          const paddle::optional<DenseTensor>& sin,
                          const paddle::optional<DenseTensor>& cos,
+                         const paddle::optional<DenseTensor>& position_ids,
                          const DenseTensor& dout_q,
                          const paddle::optional<DenseTensor>& dout_k,
                          const paddle::optional<DenseTensor>& dout_v,
+                         bool use_neox_rotary_style,
                          DenseTensor* dq,
                          DenseTensor* dk,
                          DenseTensor* dv) {
@@ -58,6 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   phi::Array<T*, 3> outs_data;
   phi::Array<const T*, 3> ins_data;
   phi::Array<const T*, 2> sin_cos_data;
+  const int64_t* position_ids_data = NULL;
 
   ins_data[0] = dout_q.data<T>();
   outs_data[0] = dq->data<T>();
@@ -86,21 +89,42 @@ void FusedRopeGradKernel(const Context& dev_ctx,
     sin_cos_data[1] = cos->data<T>();
     flag_sin_cos = true;
+
+    if (position_ids.get_ptr()) {
+      position_ids_data = position_ids->data<int64_t>();
+    }
   }
 
   int sign = -1;
-  VectorizedFusedRopeKernel<T, MPType, vec_size>
-      <<<grid, block, 0, stream>>>(ins_data,
-                                   sin_cos_data,
-                                   flag_sin_cos,
-                                   sign,
-                                   batch_size,
-                                   seq_len,
-                                   num_heads,
-                                   head_dim,
-                                   outs_data,
-                                   num_inputs,
-                                   div_c);
+  if (use_neox_rotary_style) {
+    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  } else {
+    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  }
 }
 }  // namespace fusion
......
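A note on the `sign` variable above: the backward kernel launches with `sign = -1` while the forward kernel below uses `sign = 1`, because the gradient of a rotation is the transposed rotation, that is, the same rotation with the sine negated. A minimal scalar illustration of that identity (a hedged sketch, not code from the commit):

```python
def rope_pair(x0, x1, sin, cos, sign=1):
    # rotate one (even, odd) channel pair; sign flips the direction of rotation
    return (x0 * cos - sign * x1 * sin, x1 * cos + sign * x0 * sin)

# forward rotation with sign=+1, backward (Jacobian-transpose) with sign=-1;
# when cos**2 + sin**2 == 1 the two compose back to the identity
y0, y1 = rope_pair(0.3, -1.2, sin=0.6, cos=0.8, sign=1)
x0, x1 = rope_pair(y0, y1, sin=0.6, cos=0.8, sign=-1)
assert abs(x0 - 0.3) < 1e-9 and abs(x1 + 1.2) < 1e-9
```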
@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
                      const paddle::optional<DenseTensor>& v,
                      const paddle::optional<DenseTensor>& sin,
                      const paddle::optional<DenseTensor>& cos,
+                     const paddle::optional<DenseTensor>& position_ids,
+                     bool use_neox_rotary_style,
                      DenseTensor* out_q,
                      DenseTensor* out_k,
                      DenseTensor* out_v) {
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
   phi::Array<T*, 3> outs_data;
   phi::Array<const T*, 3> ins_data;
   phi::Array<const T*, 2> sin_cos_data;
+  const int64_t* position_ids_data = NULL;
 
   ins_data[0] = q.data<T>();
   outs_data[0] = out_q->data<T>();
@@ -109,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx,
           "The batch_size and num_heads of sin and cos must be 1."));
     }
     int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0;
-    PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
-                       sin_dims[sin_seq_len_dim] == seq_len),
-                      true,
-                      phi::errors::InvalidArgument(
-                          "The seq_len and head_dim of sin and cos "
-                          "must be the same as those of q. But recieved sin's "
-                          "shape is {%s}, q's shape is {%s}.",
-                          sin_dims,
-                          q.dims()));
+
+    if (position_ids.get_ptr()) {
+      PADDLE_ENFORCE_EQ(
+          (sin_dims[dims_size - 1] == head_dim &&
+           sin_dims[sin_seq_len_dim] >= seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The seq_len of sin and cos must be greater than or equal to "
+              "this of q. The head_dim of sin and cos must be the same as this "
+              "of q. But recieved sin's "
+              "shape is {%s}, q's shape is {%s}.",
+              sin_dims,
+              q.dims()));
+
+      auto position_ids_dims = position_ids.get_ptr()->dims();
+      PADDLE_ENFORCE_EQ(position_ids_dims.size(),
+                        2,
+                        phi::errors::InvalidArgument(
+                            "The dims of position_ids is expected to "
+                            "be 2, but recieved %d.",
+                            position_ids_dims.size()));
+
+      PADDLE_ENFORCE_EQ(
+          (position_ids_dims[0] == batch_size &&
+           position_ids_dims[1] == seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The batch_size and seq_len of position_ids must be the same as "
+              "those of q. But recieved position_ids's "
+              "shape is {%s}, q's shape is {%s}.",
+              position_ids_dims,
+              q.dims()));
+
+      position_ids_data = position_ids->data<int64_t>();
+    } else {
+      PADDLE_ENFORCE_EQ(
+          (sin_dims[dims_size - 1] == head_dim &&
+           sin_dims[sin_seq_len_dim] == seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The seq_len and head_dim of sin and cos "
+              "must be the same as those of q. But recieved sin's "
+              "shape is {%s}, q's shape is {%s}.",
+              sin_dims,
+              q.dims()));
+    }
 
     sin_cos_data[0] = sin->data<T>();
     sin_cos_data[1] = cos->data<T>();
@@ -126,18 +166,35 @@ void FusedRopeKernel(const Context& dev_ctx,
   }
 
   int sign = 1;
-  VectorizedFusedRopeKernel<T, MPType, vec_size>
-      <<<grid, block, 0, stream>>>(ins_data,
-                                   sin_cos_data,
-                                   flag_sin_cos,
-                                   sign,
-                                   batch_size,
-                                   seq_len,
-                                   num_heads,
-                                   head_dim,
-                                   outs_data,
-                                   num_inputs,
-                                   div_c);
+  if (use_neox_rotary_style) {
+    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  } else {
+    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  }
 }
 }  // namespace fusion
 }  // namespace phi
......
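The enforcement block above changes the shape contract when `position_ids` is given: `sin`/`cos` then act as a lookup table and only need to be at least as long as `q` along the sequence axis, while `position_ids` itself must be `[batch_size, seq_len]`. A hedged paraphrase of those checks in plain Python, with illustrative names:

```python
def check_rope_shapes(q_shape, sin_shape, position_ids_shape=None):
    batch_size, seq_len, num_heads, head_dim = q_shape
    # sin/cos are [seq_len, head_dim] or [1, seq_len, 1, head_dim]
    sin_seq_len_dim = 1 if len(sin_shape) == 4 else 0
    assert sin_shape[-1] == head_dim
    if position_ids_shape is not None:
        # sin/cos act as a table indexed by position_ids, so they only
        # need to be at least seq_len rows long
        assert sin_shape[sin_seq_len_dim] >= seq_len
        assert list(position_ids_shape) == [batch_size, seq_len]
    else:
        assert sin_shape[sin_seq_len_dim] == seq_len

# e.g. a 16-row table used for an 8-token batch of shape [2, 8, 2, 16]
check_rope_shapes([2, 8, 2, 16], [1, 16, 1, 16], position_ids_shape=[2, 8])
```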
@@ -20,17 +20,71 @@ namespace phi {
 namespace fusion {
 
 template <typename T, typename MPType, int VecSize = 2>
-__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
-                                          phi::Array<const T*, 2> sin_cos_data,
-                                          bool flag_sin_cos,
-                                          int sign,
-                                          int64_t batch_size,
-                                          int64_t seq_len,
-                                          int64_t num_heads,
-                                          int64_t head_dim,
-                                          phi::Array<T*, 3> outs_data,
-                                          int num_inputs,
-                                          MPType div_c) {
+__device__ void VectorizedGetSinCos(phi::Array<const T*, 2> sin_cos_data,
+                                    const int64_t* position_ids_data,
+                                    bool flag_sin_cos,
+                                    int64_t index,
+                                    int64_t seq_len,
+                                    int64_t num_heads,
+                                    int64_t head_dim,
+                                    MPType* out_sin,
+                                    MPType* out_cos,
+                                    MPType div_c) {
+  MPType* sin_value = out_sin;
+  MPType* cos_value = out_cos;
+
+  if (flag_sin_cos) {
+#pragma unroll
+    for (int64_t nx = 0; nx < VecSize; ++nx) {
+      int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+      int64_t pos_seq_ori = index_wc / (num_heads * head_dim);
+      int64_t pos_seq;
+      if (position_ids_data) {
+        int64_t pos_bs = (index + nx) / (seq_len * num_heads * head_dim);
+        int64_t index_ids = pos_bs * seq_len + pos_seq_ori;
+        pos_seq = position_ids_data[index_ids];
+      } else {
+        pos_seq = pos_seq_ori;
+      }
+      int64_t pos_head = index_wc % head_dim;
+      int64_t index_sc = pos_seq * head_dim + pos_head;
+      const T* sin_input = sin_cos_data[0] + index_sc;
+      const T* cos_input = sin_cos_data[1] + index_sc;
+
+      sin_value[nx] = static_cast<MPType>(sin_input[0]);
+      cos_value[nx] = static_cast<MPType>(cos_input[0]);
+    }
+  } else {
+#pragma unroll
+    for (int nx = 0; nx < VecSize; ++nx) {
+      // get sin_index and cos_index
+      int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+      int64_t pos_seq = index_wc / (num_heads * head_dim);
+      MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
+      MPType indicses =
+          static_cast<MPType>(1) /
+          pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
+      MPType value = pos_seq * indicses;
+      sin_value[nx] = sin(value);
+      cos_value[nx] = cos(value);
+    }
+  }
+}
+
+template <typename T, typename MPType, int VecSize = 2>
+__global__ void VectorizedFusedRopeWithRotateEveryTwoKernel(
+    phi::Array<const T*, 3> ins_data,
+    phi::Array<const T*, 2> sin_cos_data,
+    const int64_t* position_ids_data,
+    bool flag_sin_cos,
+    int sign,
+    int64_t batch_size,
+    int64_t seq_len,
+    int64_t num_heads,
+    int64_t head_dim,
+    phi::Array<T*, 3> outs_data,
+    int num_inputs,
+    MPType div_c) {
   int64_t index =
       (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
        threadIdx.x) *
@@ -46,34 +100,16 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
   constexpr int kVectorsPerThread = VecSize / 2;
 
   for (; index < size; index += stride) {
-    if (flag_sin_cos) {
-#pragma unroll
-      for (int64_t nx = 0; nx < VecSize; ++nx) {
-        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-        int64_t pos_seq = index_wc / (num_heads * head_dim);
-        int64_t pos_head = index_wc % head_dim;
-        int64_t index_sc = pos_seq * head_dim + pos_head;
-        const T* sin_input = sin_cos_data[0] + index_sc;
-        const T* cos_input = sin_cos_data[1] + index_sc;
-
-        sin_value[nx] = static_cast<MPType>(sin_input[0]);
-        cos_value[nx] = static_cast<MPType>(cos_input[0]);
-      }
-    } else {
-#pragma unroll
-      for (int nx = 0; nx < VecSize; ++nx) {
-        // get sin_index and cos_index
-        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-        int64_t pos_seq = index_wc / (num_heads * head_dim);
-        MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
-        MPType indicses =
-            static_cast<MPType>(1) /
-            pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
-        MPType value = pos_seq * indicses;
-        sin_value[nx] = sin(value);
-        cos_value[nx] = cos(value);
-      }
-    }
+    VectorizedGetSinCos(sin_cos_data,
+                        position_ids_data,
+                        flag_sin_cos,
+                        index,
+                        seq_len,
+                        num_heads,
+                        head_dim,
+                        sin_value,
+                        cos_value,
+                        div_c);
 
 #pragma unroll
     for (int iter = 0; iter < 3; iter++) {
@@ -102,5 +138,74 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
   }
 }
 
+template <typename T, typename MPType, int VecSize = 2>
+__global__ void VectorizedFusedRopeWithRotateHalfKernel(
+    phi::Array<const T*, 3> ins_data,
+    phi::Array<const T*, 2> sin_cos_data,
+    const int64_t* position_ids_data,
+    bool flag_sin_cos,
+    int sign,
+    int64_t batch_size,
+    int64_t seq_len,
+    int64_t num_heads,
+    int64_t head_dim,
+    phi::Array<T*, 3> outs_data,
+    int num_inputs,
+    MPType div_c) {
+  int64_t index =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       threadIdx.x) *
+      VecSize;
+  int64_t stride = static_cast<int64_t>(gridDim.x) *
+                   static_cast<int64_t>(blockDim.x) * VecSize;
+  int64_t size = batch_size * seq_len * num_heads * head_dim;
+  MPType sin_value[VecSize];
+  MPType cos_value[VecSize];
+  MPType result[VecSize];
+  T store[VecSize];
+  using VecType = phi::AlignedVector<T, VecSize>;
+  constexpr int kVectorsPerThread = VecSize / 2;
+
+  for (; index < size; index += stride) {
+    VectorizedGetSinCos(sin_cos_data,
+                        position_ids_data,
+                        flag_sin_cos,
+                        index,
+                        seq_len,
+                        num_heads,
+                        head_dim,
+                        sin_value,
+                        cos_value,
+                        div_c);
+
+    // use rotate_half mode
+    int stride_r = head_dim / 2;
+#pragma unroll
+    for (int iter = 0; iter < 3; iter++) {
+      if (iter > num_inputs) break;
+      // get value_index and rotate_half_index
+      int index_v = index;
+      int index_r = (index % head_dim) < stride_r ? (index + stride_r)
+                                                  : (index - stride_r);
+      MPType sign_r = (index % head_dim) < stride_r ? static_cast<MPType>(-1)
+                                                    : static_cast<MPType>(1);
+      const T* input_v = ins_data[iter] + index_v;
+      const T* input_r = ins_data[iter] + index_r;
+      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
+
+#pragma unroll
+      for (int nx = 0; nx < VecSize; ++nx) {
+        MPType p0 = static_cast<MPType>(input_v[nx]);
+        MPType p1 = static_cast<MPType>(input_r[nx]);
+        result[nx] = cos_value[nx] * p0 + sign * sign_r * sin_value[nx] * p1;
+        store[nx] = static_cast<T>(result[nx]);
+      }
+      out[0] = *(reinterpret_cast<VecType*>(store));
+    }
+  }
+}
+
 }  // namespace fusion
 }  // namespace phi
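The per-element pairing done by the new `VectorizedFusedRopeWithRotateHalfKernel` (partner index in the other half of `head_dim`, minus sign on the first half) can be mirrored in scalar Python. A hedged sketch for a single head, not the kernel itself:

```python
def rotate_half_reference(x, sin, cos, sign=1):
    # x, sin, cos: plain lists of length head_dim; sign=1 forward, -1 backward
    head_dim = len(x)
    stride_r = head_dim // 2
    out = [0.0] * head_dim
    for i in range(head_dim):
        # the partner element lives in the other half of the head dimension
        index_r = i + stride_r if i < stride_r else i - stride_r
        # the first half subtracts the sine term, the second half adds it
        sign_r = -1.0 if i < stride_r else 1.0
        out[i] = cos[i] * x[i] + sign * sign_r * sin[i] * x[index_r]
    return out

# tiny head_dim=4 example: elements pair up as (x0, x2) and (x1, x3)
print(rotate_half_reference([1.0, 2.0, 3.0, 4.0], [0.0] * 4, [1.0] * 4))
```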
@@ -17,16 +17,26 @@ from paddle import _C_ops
 from paddle.framework import in_dynamic_mode
 
 
-def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
+def fused_rotary_position_embedding(
+    q,
+    k=None,
+    v=None,
+    sin=None,
+    cos=None,
+    position_ids=None,
+    use_neox_rotary_style=True,
+):
     r"""
     Fused rotary position embedding.
 
     Args:
-        q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
-        k (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
-        v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
-        sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
-        cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
+        q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
+        k (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
+        v (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
+        sin (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of sin must be [seq_len, head_dim] or [1, seq_len, 1, head_dim] and head_dim must be a multiple of 2.
+        cos (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of cos must be [seq_len, head_dim] or [1, seq_len, 1, head_dim] and head_dim must be a multiple of 2.
+        position_ids (Tensor, optional): The input tensor. The data type is int64. The shape of position_ids must be [batch_size, seq_len].
+        use_neox_rotary_style(optional|bool): When the use_neox_rotary_style is True, every two adjacent numbers are calculated. When the use_neox_rotary_style is False, the numbers corresponding to the positions of the front half and back half segments are calculated. Default True.
 
     Returns:
         out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` .
@@ -36,23 +46,51 @@ def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
         .. code-block:: python
 
-            # required: gpu
-            import paddle
-            from paddle.incubate.nn.functional import fused_rotary_position_embedding
-
-            q = paddle.randn([1, 1, 4, 10], dtype='float16')
-            k = paddle.randn([1, 1, 4, 10], dtype='float16')
-            v = paddle.randn([1, 1, 4, 10], dtype='float16')
-            out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
-
-            x = paddle.randn([1, 1, 1, 10], dtype='float16')
-            y = paddle.randn([1, 1, 1, 10], dtype='float16')
-            sin = paddle.sin(x)
-            cos = paddle.cos(y)
-            out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos)
+            >>> # required: gpu
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> import paddle
+            >>> from paddle.incubate.nn.functional import fused_rotary_position_embedding
+            >>> paddle.device.set_device('gpu')
+
+            >>> # batch_size = 2
+            >>> # seq_len = 2
+            >>> # num_heads = 2
+            >>> # head_dim = 2
+
+            >>> paddle.seed(1204)
+            >>> # q, k, v: [batch_size, seq_len, num_heads, head_dim]
+            >>> q = paddle.randn([2, 2, 2, 2], dtype='float16')
+            >>> k = paddle.randn([2, 2, 2, 2], dtype='float16')
+            >>> v = paddle.randn([2, 2, 2, 2], dtype='float16')
+
+            >>> # sin, cos: [1, seq_len, 1, head_dim]
+            >>> x = paddle.randn([1, 2, 1, 2], dtype='float16')
+            >>> y = paddle.randn([1, 2, 1, 2], dtype='float16')
+            >>> sin = paddle.sin(x)
+            >>> cos = paddle.cos(y)
+
+            >>> # position_ids: [batch_size, seq_len]
+            >>> position_ids = paddle.randint(high=2, shape=[2, 2], dtype='int64')
+
+            >>> # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
+            >>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
+            >>> print(out_q)
+            Tensor(shape=[2, 2, 2, 2], dtype=float16, place=Place(gpu:0), stop_gradient=True,
+            [[[[-0.54931641,  0.64990234],
+               [-1.08691406,  1.18261719]],
+              [[ 0.57812500,  0.11749268],
+               [-0.63281250,  0.15551758]]],
+             [[[-0.77050781,  0.07733154],
+               [-0.73730469, -0.16735840]],
+              [[ 0.07116699, -0.90966797],
+               [-0.03628540, -0.20202637]]]])
     """
     if in_dynamic_mode():
-        return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos)
+        return _C_ops.fused_rotary_position_embedding(
+            q, k, v, sin, cos, position_ids, use_neox_rotary_style
+        )
 
     raise RuntimeError(
         "This feature is currently supported only in dynamic mode and with CUDAPlace."
......
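One way to read the new `position_ids` argument: the rotation row used for token `t` of batch `b` is `sin[position_ids[b, t]]` rather than `sin[t]`. The unit test below emulates this with a gather; here is a hedged standalone sketch of the same idea (the function name is illustrative, not part of the API):

```python
import paddle

def gather_sin_cos(sin, cos, position_ids):
    # sin, cos: [1, max_seq_len, 1, head_dim]; position_ids: [batch_size, seq_len]
    sin_2d = sin.squeeze(axis=[0, 2])                 # [max_seq_len, head_dim]
    cos_2d = cos.squeeze(axis=[0, 2])                 # [max_seq_len, head_dim]
    sin_sel = sin_2d[position_ids].unsqueeze(axis=2)  # [batch_size, seq_len, 1, head_dim]
    cos_sel = cos_2d[position_ids].unsqueeze(axis=2)  # [batch_size, seq_len, 1, head_dim]
    return sin_sel, cos_sel
```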
@@ -41,6 +41,24 @@ def mult_qkv(value, cos_tensor, sin_tensor):
     return query
 
 
+def mult_qkv_rotate_half(value, cos_tensor, sin_tensor):
+    rotate_half_q = paddle.reshape(
+        paddle.concat(
+            [
+                -value[..., value.shape[-1] // 2 :],
+                value[..., : value.shape[-1] // 2],
+            ],
+            axis=-1,
+        ),
+        paddle.shape(value),
+    )
+    query = paddle.add(
+        paddle.multiply(value, cos_tensor),
+        paddle.multiply(rotate_half_q, sin_tensor),
+    )
+    return query
+
+
 def get_sin_cos_tensor(seq_len, head_dim, sign):
     pos_seq = paddle.arange(0, seq_len, 1, dtype="float32")
     indices = paddle.arange(0, head_dim, 2, dtype="float32")
@@ -74,22 +92,38 @@ def get_sin_cos_tensor(seq_len, head_dim, sign):
     return tensor_sin, tensor_cos
 
 
-def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
+def paddle_fused_rotary_position_embedding(
+    init_q, init_k, init_v, position_ids=None, use_neox_rotary_style=True
+):
     # permute q, k, v from [batch_size, seq_len, num_heads, head_dim]
     # to [batch_size, num_heads, seq_len, head_dim]
     q, k, v = deal_qkv(init_q, init_k, init_v)
 
-    sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1)
+    sign = -1 if use_neox_rotary_style else 1
+    sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], sign)
+
+    if position_ids is not None:
+        sin_tensor = sin_tensor.squeeze(axis=[0, 2])  # [seq_len, dim]
+        cos_tensor = cos_tensor.squeeze(axis=[0, 2])  # [seq_len, dim]
+        sin_tensor = sin_tensor[position_ids].unsqueeze(
+            2
+        )  # [bs, seq_len, 1, dim]
+        cos_tensor = cos_tensor[position_ids].unsqueeze(
+            2
+        )  # [bs, seq_len, 1, dim]
 
+    # permute sin, cos from [1, seq_len, 1, head_dim]
+    # to [1, 1, seq_len, head_dim]
     perm = [0, 2, 1, 3]
     sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
     cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)
 
-    query = mult_qkv(q, cos_tensor, sin_tensor)
-    value = mult_qkv(v, cos_tensor, sin_tensor)
-    key = mult_qkv(k, cos_tensor, sin_tensor)
+    if use_neox_rotary_style:
+        query = mult_qkv(q, cos_tensor, sin_tensor)
+        value = mult_qkv(v, cos_tensor, sin_tensor)
+        key = mult_qkv(k, cos_tensor, sin_tensor)
+    else:
+        query = mult_qkv_rotate_half(q, cos_tensor, sin_tensor)
+        value = mult_qkv_rotate_half(v, cos_tensor, sin_tensor)
+        key = mult_qkv_rotate_half(k, cos_tensor, sin_tensor)
 
     # permute the result back to [batch_size, seq_len, num_heads, head_dim]
     r_query, r_key, r_value = deal_qkv(query, key, value)
@@ -102,7 +136,7 @@ def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
     )
 
 class TestFusedRotaryPositionEmbedding(unittest.TestCase):
     def setUp(self):
-        self.shape = [1, 8, 2, 16]
+        self.shape = [2, 8, 2, 16]
         self.dtype = 'float32'
         self.training = True
         self.seed = 1203
@@ -112,7 +146,14 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
         tmp.stop_gradient = False
         return tmp
 
-    def get_forward_backward(self, rope_function, seed, flag=0):
+    def get_forward_backward(
+        self,
+        rope_function,
+        seed,
+        flag=False,
+        use_neox_rotary_style=True,
+        position_ids=None,
+    ):
         paddle.disable_static()
         paddle.seed(seed)
         fw = []
@@ -120,15 +161,45 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
         tensor_q = self.get_paddle_tensor()
         tensor_k = self.get_paddle_tensor()
         tensor_v = self.get_paddle_tensor()
-        if flag:
-            tensor_sin, tensor_cos = get_sin_cos_tensor(
-                tensor_q.shape[1], tensor_q.shape[3], 1
-            )
-            out_q, out_k, out_v = rope_function(
-                tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos
-            )
-        else:
-            out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v)
+        if use_neox_rotary_style:
+            if flag:
+                tensor_sin, tensor_cos = get_sin_cos_tensor(
+                    tensor_q.shape[1], tensor_q.shape[3], 1
+                )
+                out_q, out_k, out_v = rope_function(
+                    tensor_q,
+                    tensor_k,
+                    tensor_v,
+                    tensor_sin,
+                    tensor_cos,
+                    position_ids=position_ids,
+                )
+            else:
+                out_q, out_k, out_v = rope_function(
+                    tensor_q, tensor_k, tensor_v, position_ids=position_ids
+                )
+        else:
+            if flag:
+                tensor_sin, tensor_cos = get_sin_cos_tensor(
+                    tensor_q.shape[1], tensor_q.shape[3], 1
+                )
+                out_q, out_k, out_v = rope_function(
+                    tensor_q,
+                    tensor_k,
+                    tensor_v,
+                    tensor_sin,
+                    tensor_cos,
+                    position_ids=position_ids,
+                    use_neox_rotary_style=False,
+                )
+            else:
+                out_q, out_k, out_v = rope_function(
+                    tensor_q,
+                    tensor_k,
+                    tensor_v,
+                    position_ids=position_ids,
+                    use_neox_rotary_style=False,
+                )
 
         fw.append(out_q)
         fw.append(out_k)
@@ -166,7 +237,49 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
             paddle_fused_rotary_position_embedding, seed=self.seed
         )
         f_fw, f_bw = self.get_forward_backward(
-            fused_rotary_position_embedding, seed=self.seed, flag=1
+            fused_rotary_position_embedding, seed=self.seed, flag=True
+        )
+        for i in range(len(p_fw)):
+            np.testing.assert_allclose(
+                p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
+            )
+            np.testing.assert_allclose(
+                p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
+            )
+
+    def test_fused_rope_rotate_half(self):
+        p_fw, p_bw = self.get_forward_backward(
+            paddle_fused_rotary_position_embedding,
+            seed=self.seed,
+            use_neox_rotary_style=False,
+        )
+        f_fw, f_bw = self.get_forward_backward(
+            fused_rotary_position_embedding,
+            seed=self.seed,
+            use_neox_rotary_style=False,
+        )
+        for i in range(len(p_fw)):
+            np.testing.assert_allclose(
+                p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
+            )
+            np.testing.assert_allclose(
+                p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
+            )
+
+    def test_fused_rope_position_ids(self):
+        position_ids = paddle.to_tensor(
+            [[7, 5, 4, 6, 3, 1, 2, 0], [3, 1, 4, 0, 7, 6, 5, 2]]
+        )
+        p_fw, p_bw = self.get_forward_backward(
+            paddle_fused_rotary_position_embedding,
+            seed=self.seed,
+            position_ids=position_ids,
+        )
+        f_fw, f_bw = self.get_forward_backward(
+            fused_rotary_position_embedding,
+            seed=self.seed,
+            flag=True,
+            position_ids=position_ids,
         )
         for i in range(len(p_fw)):
             np.testing.assert_allclose(
......