diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
index 4784acf6467435a6a80f2eb25d6420969c2bcbef..442317eb53d98062d0cbc7fc8d0b9b743be940b2 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -19,6 +19,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
+
 namespace phi {
 namespace fusion {
 
diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
index f837793860a70e3d54de7a636cc0935368ecf402..f6dcbc2a9038f06c11e4509edac047efeb867103 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -19,6 +19,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
+
 namespace phi {
 namespace fusion {
 
@@ -35,13 +36,14 @@ void FusedRopeKernel(const Context& dev_ctx,
   int64_t numel = q.numel();
   if (numel <= 0) return;
   dev_ctx.template Alloc<T>(out_q);
-  // small size for broadcast
+
+  // q.shape: [batch_size, seq_len, num_heads, head_dim]
   auto batch_size = q.dims()[0];
+  auto seq_len = q.dims()[1];
   auto num_heads = q.dims()[2];
   auto head_dim = q.dims()[3];
-  auto seq_len = q.dims()[1];
-  PADDLE_ENFORCE_NE(head_dim % 2,
-                    1,
+  PADDLE_ENFORCE_EQ(head_dim % 2,
+                    0,
                     phi::errors::InvalidArgument(
                         "The head_dim of input must be a multiple of 2."));
 
@@ -85,26 +87,37 @@ void FusedRopeKernel(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(sin.get_ptr()->dims(),
                       cos.get_ptr()->dims(),
                       phi::errors::InvalidArgument(
-                          "The dims of sin and cos must be the same."));
+                          "The dims of sin and cos must be the same. But "
+                          "received sin's dims is {%s}, cos's dims is {%s}.",
+                          sin.get_ptr()->dims(),
+                          cos.get_ptr()->dims()));
+
     auto sin_dims = sin.get_ptr()->dims();
     int dims_size = sin_dims.size();
-    PADDLE_ENFORCE_NE((dims_size == 2 || dims_size == 4),
-                      false,
-                      phi::errors::InvalidArgument(
-                          "The dims of sin and cos must be 2 or 4."));
+    PADDLE_ENFORCE_EQ(
+        (dims_size == 2 || dims_size == 4),
+        true,
+        phi::errors::InvalidArgument("The dims of sin and cos are expected "
+                                     "to be 2 or 4, but received %d.",
+                                     dims_size));
     if (dims_size == 4) {
-      PADDLE_ENFORCE_NE(
-          (sin_dims[0] == 1 && sin_dims[1] == 1),
-          false,
+      // sin.shape: [1, seq_len, 1, head_dim]
+      PADDLE_ENFORCE_EQ(
+          (sin_dims[0] == 1 && sin_dims[2] == 1),
+          true,
           phi::errors::InvalidArgument(
               "The batch_size and num_heads of sin and cos must be 1."));
     }
-    PADDLE_ENFORCE_NE(
-        (sin_dims[dims_size - 1] == head_dim &&
-         sin_dims[dims_size - 2] == seq_len),
-        false,
-        phi::errors::InvalidArgument("The seq_len and head_dim of sin and cos "
-                                     "must be the same as those of q."));
+    int sin_seq_len_dim = (dims_size == 4) ? 1 : 0;
+    PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
+                       sin_dims[sin_seq_len_dim] == seq_len),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "The seq_len and head_dim of sin and cos "
+                          "must be the same as those of q. But received sin's "
+                          "shape is {%s}, q's shape is {%s}.",
+                          sin_dims,
+                          q.dims()));
     sin_cos_data[0] = sin->data<T>();
     sin_cos_data[1] = cos->data<T>();
 
diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py
index 737f2850d96cd032cf6b2db6a8c9a734cf13ec9e..9842fbf1f4ee8c4b0f1140f8928a0da4fb609f1f 100644
--- a/test/legacy_test/test_fused_rotary_position_embedding.py
+++ b/test/legacy_test/test_fused_rotary_position_embedding.py
@@ -64,27 +64,35 @@ def get_sin_cos_tensor(seq_len, head_dim, sign):
 
     tensor_sin = paddle.reshape(
         paddle.to_tensor(sin_sin),
-        [1, 1, seq_len, head_dim],
+        [1, seq_len, 1, head_dim],
     )
     tensor_cos = paddle.reshape(
         paddle.to_tensor(cos_cos),
-        [1, 1, seq_len, head_dim],
+        [1, seq_len, 1, head_dim],
     )
 
     return tensor_sin, tensor_cos
 
 
 def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
+    # permute q, k, v from [batch_size, seq_len, num_heads, head_dim]
+    # to [batch_size, num_heads, seq_len, head_dim]
     q, k, v = deal_qkv(init_q, init_k, init_v)
 
     sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1)
 
+    # permute sin, cos from [1, seq_len, 1, head_dim]
+    # to [1, 1, seq_len, head_dim]
+    perm = [0, 2, 1, 3]
+    sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
+    cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)
+
     query = mult_qkv(q, cos_tensor, sin_tensor)
     value = mult_qkv(v, cos_tensor, sin_tensor)
     key = mult_qkv(k, cos_tensor, sin_tensor)
 
+    # permute the result back to [batch_size, seq_len, num_heads, head_dim]
     r_query, r_key, r_value = deal_qkv(query, key, value)
-
     return r_query, r_key, r_value
 
 
@@ -94,7 +102,7 @@ def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
 )
 class TestFusedRotaryPositionEmbedding(unittest.TestCase):
     def setUp(self):
-        self.shape = [1, 16, 1, 16]
+        self.shape = [1, 8, 2, 16]
         self.dtype = 'float32'
         self.training = True
         self.seed = 1203
@@ -138,7 +146,7 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
 
         return fw, bw
 
-    def test_fused_dropout_add(self):
+    def test_fused_rope(self):
         p_fw, p_bw = self.get_forward_backward(
             paddle_fused_rotary_position_embedding, seed=self.seed
         )
@@ -153,7 +161,7 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
                 p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
             )
 
-    def test_fused_dropout_add_sin_cos(self):
+    def test_fused_rope_with_sin_cos(self):
         p_fw, p_bw = self.get_forward_backward(
             paddle_fused_rotary_position_embedding, seed=self.seed
         )