Unverified    Commit f60c698f    authored by Yiqun Liu, committed by GitHub

Fix the shape of input sin and cos for fused_rope. (#56132)

* Fix the shape of input sin and cos for fused_rope.

* Update shape in unittest.

Parent: 22dbceca
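For context: fused_rotary_position_embedding takes q of shape [batch_size, seq_len, num_heads, head_dim], and this commit moves the optional sin/cos inputs from [1, 1, seq_len, head_dim] to [1, seq_len, 1, head_dim], so the singleton axes line up with the batch and heads axes of q. A minimal NumPy sketch of the broadcast the new layout enables (the rotate-half pairing below is the common interleaved RoPE convention and is an assumption, not taken from the kernel):

```python
import numpy as np

batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16

# q follows the fused kernel's input layout.
q = np.random.rand(batch_size, seq_len, num_heads, head_dim).astype("float32")

# sin/cos carry one value per (position, channel); the singleton batch and
# head axes let broadcasting reuse them across batches and heads.
pos = np.arange(seq_len)[:, None]
inv_freq = 1.0 / (10000 ** (np.arange(0, head_dim, 2) / head_dim))
angles = np.repeat(pos * inv_freq, 2, axis=-1)     # [seq_len, head_dim]
sin = np.sin(angles).reshape(1, seq_len, 1, head_dim)
cos = np.cos(angles).reshape(1, seq_len, 1, head_dim)

def rotate_half(x):
    # Mix each even/odd channel pair: (x0, x1) -> (-x1, x0).
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return np.stack([-x2, x1], axis=-1).reshape(x.shape)

out_q = q * cos + rotate_half(q) * sin             # broadcasts over batch, heads
print(out_q.shape)                                 # (2, 8, 4, 16)
```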
@@ -19,6 +19,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
 namespace phi {
 namespace fusion {
...
@@ -19,6 +19,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
 namespace phi {
 namespace fusion {
@@ -35,13 +36,14 @@ void FusedRopeKernel(const Context& dev_ctx,
   int64_t numel = q.numel();
   if (numel <= 0) return;
   dev_ctx.template Alloc<T>(out_q);
-  // small size for broadcast
+  // q.shape: [batch_size, seq_len, num_heads, head_dim]
   auto batch_size = q.dims()[0];
+  auto seq_len = q.dims()[1];
   auto num_heads = q.dims()[2];
   auto head_dim = q.dims()[3];
-  auto seq_len = q.dims()[1];
-  PADDLE_ENFORCE_NE(head_dim % 2,
-                    1,
+  PADDLE_ENFORCE_EQ(head_dim % 2,
+                    0,
                     phi::errors::InvalidArgument(
                         "The head_dim of input must be a multiple of 2."));
@@ -85,26 +87,37 @@ void FusedRopeKernel(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(sin.get_ptr()->dims(),
                       cos.get_ptr()->dims(),
                       phi::errors::InvalidArgument(
-                          "The dims of sin and cos must be the same."));
+                          "The dims of sin and cos must be the same. But "
+                          "recieved sin's dims is {%s}, cos's dims is {%s}.",
+                          sin.get_ptr()->dims(),
+                          cos.get_ptr()->dims()));
     auto sin_dims = sin.get_ptr()->dims();
     int dims_size = sin_dims.size();
-    PADDLE_ENFORCE_NE((dims_size == 2 || dims_size == 4),
-                      false,
-                      phi::errors::InvalidArgument(
-                          "The dims of sin and cos must be 2 or 4."));
+    PADDLE_ENFORCE_EQ(
+        (dims_size == 2 || dims_size == 4),
+        true,
+        phi::errors::InvalidArgument("The dims of sin and cos is expected to "
+                                     "be 2 or 4, but recieved %d.",
+                                     dims_size));
     if (dims_size == 4) {
-      PADDLE_ENFORCE_NE(
-          (sin_dims[0] == 1 && sin_dims[1] == 1),
-          false,
+      // sin.shape: [1, seq_len, 1, head_dim]
+      PADDLE_ENFORCE_EQ(
+          (sin_dims[0] == 1 && sin_dims[2] == 1),
+          true,
           phi::errors::InvalidArgument(
               "The batch_size and num_heads of sin and cos must be 1."));
     }
-    PADDLE_ENFORCE_NE(
-        (sin_dims[dims_size - 1] == head_dim &&
-         sin_dims[dims_size - 2] == seq_len),
-        false,
-        phi::errors::InvalidArgument("The seq_len and head_dim of sin and cos "
-                                     "must be the same as those of q."));
+    int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0;
+    PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
+                       sin_dims[sin_seq_len_dim] == seq_len),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "The seq_len and head_dim of sin and cos "
+                          "must be the same as those of q. But recieved sin's "
+                          "shape is {%s}, q's shape is {%s}.",
+                          sin_dims,
+                          q.dims()));
     sin_cos_data[0] = sin->data<T>();
     sin_cos_data[1] = cos->data<T>();
...
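Restated outside CUDA: the new checks accept either a 2-D [seq_len, head_dim] or a 4-D [1, seq_len, 1, head_dim] sin/cos, and pick the seq_len axis accordingly. A plain-Python sketch of the same logic (check_sin_cos_dims is a hypothetical helper for illustration, not kernel API):

```python
def check_sin_cos_dims(sin_shape, cos_shape, q_shape):
    # Mirrors the kernel's checks; q is [batch_size, seq_len, num_heads, head_dim].
    _, seq_len, _, head_dim = q_shape
    assert sin_shape == cos_shape, "sin and cos must have the same dims"
    dims_size = len(sin_shape)
    assert dims_size in (2, 4), "sin/cos must be 2-D or 4-D"
    if dims_size == 4:
        # 4-D layout is [1, seq_len, 1, head_dim]: batch and head axes are 1.
        assert sin_shape[0] == 1 and sin_shape[2] == 1
    sin_seq_len_dim = 1 if dims_size == 4 else 0
    assert sin_shape[-1] == head_dim and sin_shape[sin_seq_len_dim] == seq_len

check_sin_cos_dims([1, 8, 1, 16], [1, 8, 1, 16], [2, 8, 4, 16])  # 4-D form passes
check_sin_cos_dims([8, 16], [8, 16], [2, 8, 4, 16])              # 2-D form passes
```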
@@ -64,27 +64,35 @@ def get_sin_cos_tensor(seq_len, head_dim, sign):
     tensor_sin = paddle.reshape(
         paddle.to_tensor(sin_sin),
-        [1, 1, seq_len, head_dim],
+        [1, seq_len, 1, head_dim],
     )
     tensor_cos = paddle.reshape(
         paddle.to_tensor(cos_cos),
-        [1, 1, seq_len, head_dim],
+        [1, seq_len, 1, head_dim],
     )
     return tensor_sin, tensor_cos


 def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
+    # permute q, k, v from [batch_size, seq_len, num_heads, head_dim]
+    # to [batch_size, num_heads, seq_len, head_dim]
     q, k, v = deal_qkv(init_q, init_k, init_v)
     sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1)
+    # permute sin, cos from [1, seq_len, 1, head_dim]
+    # to [1, 1, seq_len, head_dim]
+    perm = [0, 2, 1, 3]
+    sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
+    cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)
     query = mult_qkv(q, cos_tensor, sin_tensor)
     value = mult_qkv(v, cos_tensor, sin_tensor)
     key = mult_qkv(k, cos_tensor, sin_tensor)
+    # permute the result back to [batch_size, seq_len, num_heads, head_dim]
     r_query, r_key, r_value = deal_qkv(query, key, value)
     return r_query, r_key, r_value
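The reference path above permutes q/k/v to [batch_size, num_heads, seq_len, head_dim], so the new [1, seq_len, 1, head_dim] sin/cos must have its two middle axes swapped before the elementwise multiply; perm = [0, 2, 1, 3] does exactly that. A quick NumPy shape check:

```python
import numpy as np

seq_len, head_dim = 8, 16
sin = np.zeros((1, seq_len, 1, head_dim))    # layout fed to the fused op
sin_t = np.transpose(sin, (0, 2, 1, 3))      # layout for the permuted q/k/v
print(sin.shape, "->", sin_t.shape)          # (1, 8, 1, 16) -> (1, 1, 8, 16)
```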
@@ -94,7 +102,7 @@ def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
 )
 class TestFusedRotaryPositionEmbedding(unittest.TestCase):
     def setUp(self):
-        self.shape = [1, 16, 1, 16]
+        self.shape = [1, 8, 2, 16]
         self.dtype = 'float32'
         self.training = True
         self.seed = 1203
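One way to read why the unittest shape changed (this rationale is inferred, not stated in the commit): with the old [1, 16, 1, 16] input, num_heads was 1 and seq_len equaled head_dim, so a sin/cos tensor whose seq_len landed on the wrong axis could still broadcast; with [1, 8, 2, 16] it cannot. A NumPy illustration:

```python
import numpy as np

q_old = np.zeros((1, 16, 1, 16))      # old test shape: num_heads == 1
q_new = np.zeros((1, 8, 2, 16))       # new test shape: seq_len != num_heads != 1

sin_wrong = np.zeros((1, 1, 16, 16))  # seq_len mistakenly on the heads axis
_ = q_old * sin_wrong                 # silently broadcasts to (1, 16, 16, 16)

try:
    _ = q_new * np.zeros((1, 1, 8, 16))   # now a hard error: 2 vs 8 on axis 2
except ValueError as e:
    print("broadcast mismatch:", e)
```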
@@ -138,7 +146,7 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
         return fw, bw

-    def test_fused_dropout_add(self):
+    def test_fused_rope(self):
         p_fw, p_bw = self.get_forward_backward(
             paddle_fused_rotary_position_embedding, seed=self.seed
         )
@@ -153,7 +161,7 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
             p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
         )

-    def test_fused_dropout_add_sin_cos(self):
+    def test_fused_rope_with_sin_cos(self):
         p_fw, p_bw = self.get_forward_backward(
             paddle_fused_rotary_position_embedding, seed=self.seed
         )
...