未验证 提交 f60c698f 编写于 作者: Y Yiqun Liu 提交者: GitHub

Fix the shape of input sin and cos for fused_rope. (#56132)

* Fix the shape of input sin and cos for fused_rope.

* Update shape in unittest.
上级 22dbceca
......@@ -19,6 +19,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
namespace phi {
namespace fusion {
......
......@@ -19,6 +19,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
namespace phi {
namespace fusion {
......@@ -35,13 +36,14 @@ void FusedRopeKernel(const Context& dev_ctx,
int64_t numel = q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q);
// small size for broadcast
// q.shape: [batch_size, seq_len, num_heads, head_dim]
auto batch_size = q.dims()[0];
auto seq_len = q.dims()[1];
auto num_heads = q.dims()[2];
auto head_dim = q.dims()[3];
auto seq_len = q.dims()[1];
PADDLE_ENFORCE_NE(head_dim % 2,
1,
PADDLE_ENFORCE_EQ(head_dim % 2,
0,
phi::errors::InvalidArgument(
"The head_dim of input must be a multiple of 2."));
......@@ -85,26 +87,37 @@ void FusedRopeKernel(const Context& dev_ctx,
PADDLE_ENFORCE_EQ(sin.get_ptr()->dims(),
cos.get_ptr()->dims(),
phi::errors::InvalidArgument(
"The dims of sin and cos must be the same."));
"The dims of sin and cos must be the same. But "
"recieved sin's dims is {%s}, cos's dims is {%s}.",
sin.get_ptr()->dims(),
cos.get_ptr()->dims()));
auto sin_dims = sin.get_ptr()->dims();
int dims_size = sin_dims.size();
PADDLE_ENFORCE_NE((dims_size == 2 || dims_size == 4),
false,
phi::errors::InvalidArgument(
"The dims of sin and cos must be 2 or 4."));
PADDLE_ENFORCE_EQ(
(dims_size == 2 || dims_size == 4),
true,
phi::errors::InvalidArgument("The dims of sin and cos is expected to "
"be 2 or 4, but recieved %d.",
dims_size));
if (dims_size == 4) {
PADDLE_ENFORCE_NE(
(sin_dims[0] == 1 && sin_dims[1] == 1),
false,
// sin.shape: [1, seq_len, 1, head_dim]
PADDLE_ENFORCE_EQ(
(sin_dims[0] == 1 && sin_dims[2] == 1),
true,
phi::errors::InvalidArgument(
"The batch_size and num_heads of sin and cos must be 1."));
}
PADDLE_ENFORCE_NE(
(sin_dims[dims_size - 1] == head_dim &&
sin_dims[dims_size - 2] == seq_len),
false,
phi::errors::InvalidArgument("The seq_len and head_dim of sin and cos "
"must be the same as those of q."));
int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0;
PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
sin_dims[sin_seq_len_dim] == seq_len),
true,
phi::errors::InvalidArgument(
"The seq_len and head_dim of sin and cos "
"must be the same as those of q. But recieved sin's "
"shape is {%s}, q's shape is {%s}.",
sin_dims,
q.dims()));
sin_cos_data[0] = sin->data<T>();
sin_cos_data[1] = cos->data<T>();
......
......@@ -64,27 +64,35 @@ def get_sin_cos_tensor(seq_len, head_dim, sign):
tensor_sin = paddle.reshape(
paddle.to_tensor(sin_sin),
[1, 1, seq_len, head_dim],
[1, seq_len, 1, head_dim],
)
tensor_cos = paddle.reshape(
paddle.to_tensor(cos_cos),
[1, 1, seq_len, head_dim],
[1, seq_len, 1, head_dim],
)
return tensor_sin, tensor_cos
def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
    """Apply rotary position embedding to q, k and v with unfused paddle ops.

    The inputs are laid out as [batch_size, seq_len, num_heads, head_dim].
    They are permuted to [batch_size, num_heads, seq_len, head_dim], combined
    with the sin/cos tables through ``mult_qkv``, and permuted back before
    being returned as ``(query, key, value)``.
    """
    # Work in [batch_size, num_heads, seq_len, head_dim] layout.
    q_perm, k_perm, v_perm = deal_qkv(init_q, init_k, init_v)

    # After the permutation, axis 2 is seq_len and axis 3 is head_dim.
    sin_table, cos_table = get_sin_cos_tensor(
        q_perm.shape[2], q_perm.shape[3], -1
    )

    # The tables are produced as [1, seq_len, 1, head_dim]; swap axes 1 and 2
    # so they become [1, 1, seq_len, head_dim] and line up with q/k/v.
    sin_table, cos_table = (
        paddle.transpose(x=table, perm=[0, 2, 1, 3])
        for table in (sin_table, cos_table)
    )

    rotated_q = mult_qkv(q_perm, cos_table, sin_table)
    rotated_v = mult_qkv(v_perm, cos_table, sin_table)
    rotated_k = mult_qkv(k_perm, cos_table, sin_table)

    # Restore the [batch_size, seq_len, num_heads, head_dim] layout.
    out_q, out_k, out_v = deal_qkv(rotated_q, rotated_k, rotated_v)
    return out_q, out_k, out_v
......@@ -94,7 +102,7 @@ def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
)
class TestFusedRotaryPositionEmbedding(unittest.TestCase):
def setUp(self):
self.shape = [1, 16, 1, 16]
self.shape = [1, 8, 2, 16]
self.dtype = 'float32'
self.training = True
self.seed = 1203
......@@ -138,7 +146,7 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
return fw, bw
def test_fused_dropout_add(self):
def test_fused_rope(self):
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding, seed=self.seed
)
......@@ -153,7 +161,7 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)
def test_fused_dropout_add_sin_cos(self):
def test_fused_rope_with_sin_cos(self):
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding, seed=self.seed
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册