diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
index 314f1d267a404a029dd5e759345a8fe68a27f546..4784acf6467435a6a80f2eb25d6420969c2bcbef 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -32,10 +32,9 @@ void FusedRopeGradKernel(const Context& dev_ctx,
                          DenseTensor* dq,
                          DenseTensor* dk,
                          DenseTensor* dv) {
-  int numel = dout_q.numel();
+  int64_t numel = dout_q.numel();
   if (numel <= 0) return;
   dev_ctx.template Alloc<T>(dq);
-  dq->Resize(dout_q.dims());
   // small size for broadcast
   auto batch_size = dout_q.dims()[0];
   auto num_heads = dout_q.dims()[2];
@@ -51,8 +50,8 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
-  int grid = config.block_per_grid.x;
-  int block = config.thread_per_block.x;
+  int64_t grid = config.block_per_grid.x;
+  int64_t block = config.thread_per_block.x;
   auto stream = dev_ctx.stream();
 
   phi::Array<T*, 3> outs_data;
@@ -65,7 +64,6 @@ void FusedRopeGradKernel(const Context& dev_ctx,
 
   if (dout_k.get_ptr()) {
     dev_ctx.template Alloc<T>(dk);
-    dk->Resize(dout_q.dims());
     outs_data[1] = dk->data<T>();
     ins_data[1] = dout_k->data<T>();
     num_inputs++;
@@ -73,7 +71,6 @@ void FusedRopeGradKernel(const Context& dev_ctx,
 
   if (dout_v.get_ptr()) {
     dev_ctx.template Alloc<T>(dv);
-    dv->Resize(dout_q.dims());
     outs_data[2] = dv->data<T>();
     ins_data[2] = dout_v->data<T>();
     num_inputs++;
diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
index e0f20cc80d21478664ef00355dc7eda8aa9ee395..f837793860a70e3d54de7a636cc0935368ecf402 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -32,10 +32,9 @@ void FusedRopeKernel(const Context& dev_ctx,
                      DenseTensor* out_q,
                      DenseTensor* out_k,
                      DenseTensor* out_v) {
-  int numel = q.numel();
+  int64_t numel = q.numel();
   if (numel <= 0) return;
   dev_ctx.template Alloc<T>(out_q);
-  out_q->Resize(q.dims());
   // small size for broadcast
   auto batch_size = q.dims()[0];
   auto num_heads = q.dims()[2];
@@ -51,8 +50,8 @@ void FusedRopeKernel(const Context& dev_ctx,
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
-  int grid = config.block_per_grid.x;
-  int block = config.thread_per_block.x;
+  int64_t grid = config.block_per_grid.x;
+  int64_t block = config.thread_per_block.x;
   auto stream = dev_ctx.stream();
 
   phi::Array<T*, 3> outs_data;
@@ -65,7 +64,6 @@ void FusedRopeKernel(const Context& dev_ctx,
 
   if (k.get_ptr()) {
     dev_ctx.template Alloc<T>(out_k);
-    out_k->Resize(q.dims());
     ins_data[1] = k->data<T>();
     outs_data[1] = out_k->data<T>();
     num_inputs++;
@@ -73,7 +71,6 @@ void FusedRopeKernel(const Context& dev_ctx,
 
   if (v.get_ptr()) {
     dev_ctx.template Alloc<T>(out_v);
-    out_v->Resize(q.dims());
     ins_data[2] = v->data<T>();
     outs_data[2] = out_v->data<T>();
     num_inputs++;
diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h
index b24db8ef09dfa389bf0ed39efe13393bbb44e03c..54ffba19e60c0e04761187737f13039c005e1134 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h
@@ -24,16 +24,20 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
                                           phi::Array<const T*, 2> sin_cos_data,
                                           bool flag_sin_cos,
                                           int sign,
-                                          int batch_size,
-                                          int seq_len,
-                                          int num_heads,
-                                          int head_dim,
+                                          int64_t batch_size,
+                                          int64_t seq_len,
+                                          int64_t num_heads,
+                                          int64_t head_dim,
                                           phi::Array<T*, 3> outs_data,
                                           int num_inputs,
                                           MPType div_c) {
-  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
-  int stride = gridDim.x * blockDim.x * VecSize;
-  int size = batch_size * seq_len * num_heads * head_dim;
+  int64_t index =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       threadIdx.x) *
+      VecSize;
+  int64_t stride = static_cast<int64_t>(gridDim.x) *
+                   static_cast<int64_t>(blockDim.x) * VecSize;
+  int64_t size = batch_size * seq_len * num_heads * head_dim;
   MPType sin_value[VecSize];
   MPType cos_value[VecSize];
   MPType result[VecSize];
@@ -44,11 +48,11 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
   for (; index < size; index += stride) {
     if (flag_sin_cos) {
 #pragma unroll
-      for (int nx = 0; nx < VecSize; ++nx) {
-        int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-        int pos_seq = index_wc / (num_heads * head_dim);
-        int pos_head = index_wc % head_dim;
-        int index_sc = pos_seq * head_dim + pos_head;
+      for (int64_t nx = 0; nx < VecSize; ++nx) {
+        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+        int64_t pos_seq = index_wc / (num_heads * head_dim);
+        int64_t pos_head = index_wc % head_dim;
+        int64_t index_sc = pos_seq * head_dim + pos_head;
         const T* sin_input = sin_cos_data[0] + index_sc;
         const T* cos_input = sin_cos_data[1] + index_sc;
 
@@ -59,8 +63,8 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
 #pragma unroll
       for (int nx = 0; nx < VecSize; ++nx) {
         // get sin_index and cos_index
-        int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-        int pos_seq = index_wc / (num_heads * head_dim);
+        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+        int64_t pos_seq = index_wc / (num_heads * head_dim);
         MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
         MPType indicses =
             static_cast<MPType>(1) /
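The `int` to `int64_t` changes above are an overflow fix: `numel`, the grid-stride `index`, and `size` are all products of four dimension sizes, and the old code computed them in 32-bit arithmetic (including `blockIdx.x * blockDim.x`, which is evaluated in 32 bits even when assigned to a wider type). The dropped `Resize` calls appear redundant, since each output is already allocated with its inferred dims. A quick sanity check of the overflow threshold, using an illustrative shape that is not from this patch:

```python
# Illustrative shape in the [batch_size, seq_len, num_heads, head_dim]
# layout the kernel indexes; the numbers are assumptions for demonstration.
batch_size, seq_len, num_heads, head_dim = 64, 8192, 32, 128

size = batch_size * seq_len * num_heads * head_dim
print(size)              # 2147483648 == 2**31
print(size > 2**31 - 1)  # True: a 32-bit `int size` would wrap to negative,
                         # so the kernel's `index < size` loop bound and its
                         # modulo indexing would silently misbehave.
```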
diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
index 0b29f785e21ba974cc96e0104576c4e084f0108d..e05ae63f07807e548b58c38e8419cbda3b12102d 100644
--- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
+++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
@@ -17,7 +17,7 @@ from paddle import _C_ops
 from paddle.framework import in_dynamic_mode
 
 
-def fused_rotary_position_embedding(q, k, v, sin=None, cos=None):
+def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
     r"""
     Fused rotary position embedding.
 
@@ -53,3 +53,7 @@ def fused_rotary_position_embedding(q, k, v, sin=None, cos=None):
     """
     if in_dynamic_mode():
         return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos)
+
+    raise RuntimeError(
+        "This feature is currently supported only in dynamic mode and with CUDAPlace."
+    )
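On the Python side, `k` and `v` become optional. A minimal dynamic-mode sketch; the GPU place, shapes, and dtype are illustrative assumptions, with the `[batch_size, seq_len, num_heads, head_dim]` layout read off the kernels' `dims()` indexing:

```python
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

paddle.set_device('gpu')  # the op is only implemented as a CUDA kernel

# [batch_size, seq_len, num_heads, head_dim]; head_dim is even, since RoPE
# rotates adjacent element pairs.
q = paddle.randn([2, 8, 2, 16], dtype='float16')
k = paddle.randn([2, 8, 2, 16], dtype='float16')
v = paddle.randn([2, 8, 2, 16], dtype='float16')

# sin and cos default to None, in which case the kernel computes them on
# the fly (the flag_sin_cos == false branch above).
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
```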
diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py
index 486274cbf976acbf6c72554724ae4014a162178c..737f2850d96cd032cf6b2db6a8c9a734cf13ec9e 100644
--- a/test/legacy_test/test_fused_rotary_position_embedding.py
+++ b/test/legacy_test/test_fused_rotary_position_embedding.py
@@ -168,6 +168,15 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
                 p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
             )
 
+    def test_error(self):
+        paddle.enable_static()
+        with self.assertRaises(RuntimeError):
+            static_q = paddle.static.data(
+                name="q", shape=self.shape, dtype=self.dtype
+            )
+            fused_rotary_position_embedding(static_q, static_q, static_q)
+        paddle.disable_static()
+
 
 if __name__ == '__main__':
     unittest.main()
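The new `test_error` case pins down the non-dynamic-mode behavior: the function now fails fast with a `RuntimeError` instead of silently returning `None` under static graph mode. For completeness, a sketch of a q-only call enabled by the new defaults (assuming the op still returns the full output triple, as the dynamic branch suggests):

```python
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

q = paddle.randn([2, 8, 2, 16], dtype='float16')

# k and v may now be omitted; only the first output is meaningful here.
out_q, _, _ = fused_rotary_position_embedding(q)

# Under static mode the call raises before reaching the C++ op:
paddle.enable_static()
try:
    fused_rotary_position_embedding(q)
except RuntimeError as err:
    print(err)  # "This feature is currently supported only in dynamic mode..."
finally:
    paddle.disable_static()
```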