未验证 提交 8d181e37 编写于 作者: N niuliling123 提交者: GitHub

Change index's dtype from int to int64 (#55949)

上级 4eba6478
...@@ -32,10 +32,9 @@ void FusedRopeGradKernel(const Context& dev_ctx, ...@@ -32,10 +32,9 @@ void FusedRopeGradKernel(const Context& dev_ctx,
DenseTensor* dq, DenseTensor* dq,
DenseTensor* dk, DenseTensor* dk,
DenseTensor* dv) { DenseTensor* dv) {
int numel = dout_q.numel(); int64_t numel = dout_q.numel();
if (numel <= 0) return; if (numel <= 0) return;
dev_ctx.template Alloc<T>(dq); dev_ctx.template Alloc<T>(dq);
dq->Resize(dout_q.dims());
// small size for broadcast // small size for broadcast
auto batch_size = dout_q.dims()[0]; auto batch_size = dout_q.dims()[0];
auto num_heads = dout_q.dims()[2]; auto num_heads = dout_q.dims()[2];
...@@ -51,8 +50,8 @@ void FusedRopeGradKernel(const Context& dev_ctx, ...@@ -51,8 +50,8 @@ void FusedRopeGradKernel(const Context& dev_ctx,
auto config = auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size); phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int grid = config.block_per_grid.x; int64_t grid = config.block_per_grid.x;
int block = config.thread_per_block.x; int64_t block = config.thread_per_block.x;
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
phi::Array<T*, 3> outs_data; phi::Array<T*, 3> outs_data;
...@@ -65,7 +64,6 @@ void FusedRopeGradKernel(const Context& dev_ctx, ...@@ -65,7 +64,6 @@ void FusedRopeGradKernel(const Context& dev_ctx,
if (dout_k.get_ptr()) { if (dout_k.get_ptr()) {
dev_ctx.template Alloc<T>(dk); dev_ctx.template Alloc<T>(dk);
dk->Resize(dout_q.dims());
outs_data[1] = dk->data<T>(); outs_data[1] = dk->data<T>();
ins_data[1] = dout_k->data<T>(); ins_data[1] = dout_k->data<T>();
num_inputs++; num_inputs++;
...@@ -73,7 +71,6 @@ void FusedRopeGradKernel(const Context& dev_ctx, ...@@ -73,7 +71,6 @@ void FusedRopeGradKernel(const Context& dev_ctx,
if (dout_v.get_ptr()) { if (dout_v.get_ptr()) {
dev_ctx.template Alloc<T>(dv); dev_ctx.template Alloc<T>(dv);
dv->Resize(dout_q.dims());
outs_data[2] = dv->data<T>(); outs_data[2] = dv->data<T>();
ins_data[2] = dout_v->data<T>(); ins_data[2] = dout_v->data<T>();
num_inputs++; num_inputs++;
......
...@@ -32,10 +32,9 @@ void FusedRopeKernel(const Context& dev_ctx, ...@@ -32,10 +32,9 @@ void FusedRopeKernel(const Context& dev_ctx,
DenseTensor* out_q, DenseTensor* out_q,
DenseTensor* out_k, DenseTensor* out_k,
DenseTensor* out_v) { DenseTensor* out_v) {
int numel = q.numel(); int64_t numel = q.numel();
if (numel <= 0) return; if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q); dev_ctx.template Alloc<T>(out_q);
out_q->Resize(q.dims());
// small size for broadcast // small size for broadcast
auto batch_size = q.dims()[0]; auto batch_size = q.dims()[0];
auto num_heads = q.dims()[2]; auto num_heads = q.dims()[2];
...@@ -51,8 +50,8 @@ void FusedRopeKernel(const Context& dev_ctx, ...@@ -51,8 +50,8 @@ void FusedRopeKernel(const Context& dev_ctx,
auto config = auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size); phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int grid = config.block_per_grid.x; int64_t grid = config.block_per_grid.x;
int block = config.thread_per_block.x; int64_t block = config.thread_per_block.x;
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
phi::Array<T*, 3> outs_data; phi::Array<T*, 3> outs_data;
...@@ -65,7 +64,6 @@ void FusedRopeKernel(const Context& dev_ctx, ...@@ -65,7 +64,6 @@ void FusedRopeKernel(const Context& dev_ctx,
if (k.get_ptr()) { if (k.get_ptr()) {
dev_ctx.template Alloc<T>(out_k); dev_ctx.template Alloc<T>(out_k);
out_k->Resize(q.dims());
ins_data[1] = k->data<T>(); ins_data[1] = k->data<T>();
outs_data[1] = out_k->data<T>(); outs_data[1] = out_k->data<T>();
num_inputs++; num_inputs++;
...@@ -73,7 +71,6 @@ void FusedRopeKernel(const Context& dev_ctx, ...@@ -73,7 +71,6 @@ void FusedRopeKernel(const Context& dev_ctx,
if (v.get_ptr()) { if (v.get_ptr()) {
dev_ctx.template Alloc<T>(out_v); dev_ctx.template Alloc<T>(out_v);
out_v->Resize(q.dims());
ins_data[2] = v->data<T>(); ins_data[2] = v->data<T>();
outs_data[2] = out_v->data<T>(); outs_data[2] = out_v->data<T>();
num_inputs++; num_inputs++;
......
...@@ -24,16 +24,20 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data, ...@@ -24,16 +24,20 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
phi::Array<const T*, 2> sin_cos_data, phi::Array<const T*, 2> sin_cos_data,
bool flag_sin_cos, bool flag_sin_cos,
int sign, int sign,
int batch_size, int64_t batch_size,
int seq_len, int64_t seq_len,
int num_heads, int64_t num_heads,
int head_dim, int64_t head_dim,
phi::Array<T*, 3> outs_data, phi::Array<T*, 3> outs_data,
int num_inputs, int num_inputs,
MPType div_c) { MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; int64_t index =
int stride = gridDim.x * blockDim.x * VecSize; (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
int size = batch_size * seq_len * num_heads * head_dim; threadIdx.x) *
VecSize;
int64_t stride = static_cast<int64_t>(gridDim.x) *
static_cast<int64_t>(blockDim.x) * VecSize;
int64_t size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize]; MPType sin_value[VecSize];
MPType cos_value[VecSize]; MPType cos_value[VecSize];
MPType result[VecSize]; MPType result[VecSize];
...@@ -44,11 +48,11 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data, ...@@ -44,11 +48,11 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
for (; index < size; index += stride) { for (; index < size; index += stride) {
if (flag_sin_cos) { if (flag_sin_cos) {
#pragma unroll #pragma unroll
for (int nx = 0; nx < VecSize; ++nx) { for (int64_t nx = 0; nx < VecSize; ++nx) {
int index_wc = (index + nx) % (seq_len * num_heads * head_dim); int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim); int64_t pos_seq = index_wc / (num_heads * head_dim);
int pos_head = index_wc % head_dim; int64_t pos_head = index_wc % head_dim;
int index_sc = pos_seq * head_dim + pos_head; int64_t index_sc = pos_seq * head_dim + pos_head;
const T* sin_input = sin_cos_data[0] + index_sc; const T* sin_input = sin_cos_data[0] + index_sc;
const T* cos_input = sin_cos_data[1] + index_sc; const T* cos_input = sin_cos_data[1] + index_sc;
...@@ -59,8 +63,8 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data, ...@@ -59,8 +63,8 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
#pragma unroll #pragma unroll
for (int nx = 0; nx < VecSize; ++nx) { for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index // get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim); int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim); int64_t pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0); MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses = MPType indicses =
static_cast<MPType>(1) / static_cast<MPType>(1) /
......
...@@ -17,7 +17,7 @@ from paddle import _C_ops ...@@ -17,7 +17,7 @@ from paddle import _C_ops
from paddle.framework import in_dynamic_mode from paddle.framework import in_dynamic_mode
def fused_rotary_position_embedding(q, k, v, sin=None, cos=None): def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
r""" r"""
Fused rotary position embedding. Fused rotary position embedding.
...@@ -53,3 +53,7 @@ def fused_rotary_position_embedding(q, k, v, sin=None, cos=None): ...@@ -53,3 +53,7 @@ def fused_rotary_position_embedding(q, k, v, sin=None, cos=None):
""" """
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos) return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos)
raise RuntimeError(
"This feature is currently supported only in dynamic mode and with CUDAPlace."
)
...@@ -168,6 +168,15 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase): ...@@ -168,6 +168,15 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
) )
def test_error(self):
    """Check that the fused op raises RuntimeError when called in static mode.

    Static mode is enabled only for the duration of this test and is always
    disabled again, even when the assertion fails, so the mode switch does
    not leak into whatever runs after this test.
    """
    paddle.enable_static()
    try:
        # Build the placeholder outside the assertRaises scope so that only
        # the fused op call itself can satisfy the expected RuntimeError.
        static_q = paddle.static.data(
            name="q", shape=self.shape, dtype=self.dtype
        )
        with self.assertRaises(RuntimeError):
            fused_rotary_position_embedding(static_q, static_q, static_q)
    finally:
        paddle.disable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册