Unverified commit 8d181e37, authored by niuliling123, committed by GitHub

change index's dtype from int to int64 (#55949)

Parent: 4eba6478
@@ -32,10 +32,9 @@ void FusedRopeGradKernel(const Context& dev_ctx,
                          DenseTensor* dq,
                          DenseTensor* dk,
                          DenseTensor* dv) {
-  int numel = dout_q.numel();
+  int64_t numel = dout_q.numel();
   if (numel <= 0) return;
   dev_ctx.template Alloc<T>(dq);
-  dq->Resize(dout_q.dims());
   // small size for broadcast
   auto batch_size = dout_q.dims()[0];
   auto num_heads = dout_q.dims()[2];
@@ -51,8 +50,8 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
-  int grid = config.block_per_grid.x;
-  int block = config.thread_per_block.x;
+  int64_t grid = config.block_per_grid.x;
+  int64_t block = config.thread_per_block.x;
   auto stream = dev_ctx.stream();
   phi::Array<T*, 3> outs_data;
@@ -65,7 +64,6 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   if (dout_k.get_ptr()) {
     dev_ctx.template Alloc<T>(dk);
-    dk->Resize(dout_q.dims());
     outs_data[1] = dk->data<T>();
     ins_data[1] = dout_k->data<T>();
     num_inputs++;
@@ -73,7 +71,6 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   if (dout_v.get_ptr()) {
     dev_ctx.template Alloc<T>(dv);
-    dv->Resize(dout_q.dims());
     outs_data[2] = dv->data<T>();
     ins_data[2] = dout_v->data<T>();
     num_inputs++;
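The substance of the change: `numel` was a 32-bit `int`, so for any tensor with 2^31 or more elements the count wraps to a negative value and the `if (numel <= 0) return;` guard silently skips the kernel launch. A standalone sketch of the failure mode (the shape is hypothetical, picked to land exactly one element past `INT32_MAX`):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical q of shape [4, 65536, 64, 128]:
  // 4 * 65536 * 64 * 128 = 2^31 elements, one past INT32_MAX.
  int64_t dims[4] = {4, 65536, 64, 128};
  int64_t numel64 = 1;
  for (int i = 0; i < 4; ++i) numel64 *= dims[i];
  // Narrowing to int wraps (implementation-defined before C++20)
  // to -2147483648, so the kernel's early-exit guard above would
  // treat this tensor as empty and never launch.
  int numel32 = static_cast<int>(numel64);
  printf("int64_t numel = %lld\n", static_cast<long long>(numel64));
  printf("int     numel = %d\n", numel32);
}
```

The same fix is applied to the forward kernel below.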
@@ -32,10 +32,9 @@ void FusedRopeKernel(const Context& dev_ctx,
                      DenseTensor* out_q,
                      DenseTensor* out_k,
                      DenseTensor* out_v) {
-  int numel = q.numel();
+  int64_t numel = q.numel();
   if (numel <= 0) return;
   dev_ctx.template Alloc<T>(out_q);
-  out_q->Resize(q.dims());
   // small size for broadcast
   auto batch_size = q.dims()[0];
   auto num_heads = q.dims()[2];
@@ -51,8 +50,8 @@ void FusedRopeKernel(const Context& dev_ctx,
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
-  int grid = config.block_per_grid.x;
-  int block = config.thread_per_block.x;
+  int64_t grid = config.block_per_grid.x;
+  int64_t block = config.thread_per_block.x;
   auto stream = dev_ctx.stream();
   phi::Array<T*, 3> outs_data;
@@ -65,7 +64,6 @@ void FusedRopeKernel(const Context& dev_ctx,
   if (k.get_ptr()) {
     dev_ctx.template Alloc<T>(out_k);
-    out_k->Resize(q.dims());
     ins_data[1] = k->data<T>();
     outs_data[1] = out_k->data<T>();
     num_inputs++;
@@ -73,7 +71,6 @@ void FusedRopeKernel(const Context& dev_ctx,
   if (v.get_ptr()) {
     dev_ctx.template Alloc<T>(out_v);
-    out_v->Resize(q.dims());
     ins_data[2] = v->data<T>();
     outs_data[2] = out_v->data<T>();
     num_inputs++;
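For scale, the 1-D launch-config arithmetic behind these `grid`/`block` values is roughly the following (a sketch, not `GetGpuLaunchConfig1D`'s actual implementation). Even at 2^31 elements the grid itself stays small, so widening `grid` and `block` to `int64_t` is defensive; the real overflow risk sits in `numel` and in the per-thread index math shown next.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  int64_t numel = int64_t{1} << 31;  // 2^31 elements, as in the sketch above
  int64_t vec_size = 2;              // elements processed per thread per step
  int64_t threads = 256;             // assumed threads per block
  int64_t blocks = (numel / vec_size + threads - 1) / threads;
  printf("blocks = %lld\n", static_cast<long long>(blocks));  // 4194304
}
```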
@@ -24,16 +24,20 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
                                           phi::Array<const T*, 2> sin_cos_data,
                                           bool flag_sin_cos,
                                           int sign,
-                                          int batch_size,
-                                          int seq_len,
-                                          int num_heads,
-                                          int head_dim,
+                                          int64_t batch_size,
+                                          int64_t seq_len,
+                                          int64_t num_heads,
+                                          int64_t head_dim,
                                           phi::Array<T*, 3> outs_data,
                                           int num_inputs,
                                           MPType div_c) {
-  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
-  int stride = gridDim.x * blockDim.x * VecSize;
-  int size = batch_size * seq_len * num_heads * head_dim;
+  int64_t index =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       threadIdx.x) *
+      VecSize;
+  int64_t stride = static_cast<int64_t>(gridDim.x) *
+                   static_cast<int64_t>(blockDim.x) * VecSize;
+  int64_t size = batch_size * seq_len * num_heads * head_dim;
   MPType sin_value[VecSize];
   MPType cos_value[VecSize];
   MPType result[VecSize];
@@ -44,11 +48,11 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
   for (; index < size; index += stride) {
     if (flag_sin_cos) {
 #pragma unroll
-      for (int nx = 0; nx < VecSize; ++nx) {
-        int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-        int pos_seq = index_wc / (num_heads * head_dim);
-        int pos_head = index_wc % head_dim;
-        int index_sc = pos_seq * head_dim + pos_head;
+      for (int64_t nx = 0; nx < VecSize; ++nx) {
+        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+        int64_t pos_seq = index_wc / (num_heads * head_dim);
+        int64_t pos_head = index_wc % head_dim;
+        int64_t index_sc = pos_seq * head_dim + pos_head;
         const T* sin_input = sin_cos_data[0] + index_sc;
         const T* cos_input = sin_cos_data[1] + index_sc;
@@ -59,8 +63,8 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
 #pragma unroll
       for (int nx = 0; nx < VecSize; ++nx) {
         // get sin_index and cos_index
-        int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-        int pos_seq = index_wc / (num_heads * head_dim);
+        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+        int64_t pos_seq = index_wc / (num_heads * head_dim);
         MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
         MPType indicses =
             static_cast<MPType>(1) /
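Inside the kernel the promotion has to happen before the multiply: `blockIdx.x * blockDim.x` is `unsigned int * unsigned int` and is evaluated in 32 bits regardless of the type it is assigned to, so one operand must be widened first. A minimal grid-stride kernel (illustrative, not Paddle's actual kernel) showing the pattern the diff adopts:

```cuda
// Widen one operand to int64_t so the block-offset product and the
// grid stride are both computed in 64 bits.
template <typename T, int VecSize>
__global__ void ScaleKernel(const T* in, T* out, T alpha, int64_t size) {
  int64_t index =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x * VecSize;
  for (; index < size; index += stride) {
#pragma unroll
    for (int nx = 0; nx < VecSize; ++nx) {
      if (index + nx < size) out[index + nx] = alpha * in[index + nx];
    }
  }
}
```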
@@ -17,7 +17,7 @@ from paddle import _C_ops
 from paddle.framework import in_dynamic_mode
 
-def fused_rotary_position_embedding(q, k, v, sin=None, cos=None):
+def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
     r"""
     Fused rotary position embedding.
@@ -53,3 +53,7 @@ def fused_rotary_position_embedding(q, k, v, sin=None, cos=None):
     """
     if in_dynamic_mode():
         return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos)
+
+    raise RuntimeError(
+        "This feature is currently supported only in dynamic mode and with CUDAPlace."
+    )
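For reference, a minimal dynamic-mode usage sketch of the updated signature (runs only on a CUDA device; shapes and dtype are illustrative, and the kernels above assume a `[batch_size, seq_len, num_heads, head_dim]` layout):

```python
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

# [batch_size, seq_len, num_heads, head_dim]
q = paddle.randn([2, 8, 2, 16]).astype("float16")
k = paddle.randn([2, 8, 2, 16]).astype("float16")
v = paddle.randn([2, 8, 2, 16]).astype("float16")

# k, v, sin, and cos are now all optional; with sin/cos omitted,
# the kernel computes the rotation factors itself.
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
```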
@@ -168,6 +168,15 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
                 p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
             )
 
+    def test_error(self):
+        paddle.enable_static()
+        with self.assertRaises(RuntimeError):
+            static_q = paddle.static.data(
+                name="q", shape=self.shape, dtype=self.dtype
+            )
+            fused_rotary_position_embedding(static_q, static_q, static_q)
+        paddle.disable_static()
+
 
 if __name__ == '__main__':
     unittest.main()