Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
635958d9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
635958d9
编写于
11月 18, 2022
作者:
F
feng_shuai
提交者:
GitHub
11月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize: vectorize transpose_padding (#48116)
上级
42f35841
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
78 additions
and
22 deletions
+78
-22
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
.../fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+78
-22
未找到文件。
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
浏览文件 @
635958d9
...
@@ -78,8 +78,6 @@ __global__ void transpose_qkv_padding(
...
@@ -78,8 +78,6 @@ __global__ void transpose_qkv_padding(
qkv_id
*
head_num
*
size_per_head
+
head_id
*
size_per_head
;
qkv_id
*
head_num
*
size_per_head
+
head_id
*
size_per_head
;
if
(
seq_id
<
real_seq_len
)
{
if
(
seq_id
<
real_seq_len
)
{
dst
[
threadIdx
.
x
+
dst_offset
]
=
src
[
threadIdx
.
x
+
src_offset
];
dst
[
threadIdx
.
x
+
dst_offset
]
=
src
[
threadIdx
.
x
+
src_offset
];
}
else
if
(
seq_id
<
seq_len
)
{
dst
[
threadIdx
.
x
+
dst_offset
]
=
0
;
}
}
}
}
...
@@ -91,14 +89,69 @@ __global__ void transpose_qkv_unpadding(const T *src,
...
@@ -91,14 +89,69 @@ __global__ void transpose_qkv_unpadding(const T *src,
const
int
head_num
,
const
int
head_num
,
const
int
size_per_head
,
const
int
size_per_head
,
const
int
real_seq_len
)
{
const
int
real_seq_len
)
{
int
batch_id
=
blockIdx
.
x
/
(
head_num
*
real_seq_len
);
int
batch_id
=
blockIdx
.
y
;
int
seq_id
=
blockIdx
.
x
%
real_seq_len
;
int
seq_id
=
blockIdx
.
x
;
int
head_id
=
blockIdx
.
x
%
(
head_num
*
real_seq_len
)
/
real_seq_len
;
int
head_id
=
threadIdx
.
y
;
dst
[
batch_id
*
head_num
*
real_seq_len
*
size_per_head
+
const
int
src_offset
=
batch_id
*
head_num
*
seq_len
*
size_per_head
+
seq_id
*
head_num
*
size_per_head
+
head_id
*
size_per_head
+
threadIdx
.
x
]
=
src
[
batch_id
*
head_num
*
seq_len
*
size_per_head
+
head_id
*
seq_len
*
size_per_head
+
head_id
*
seq_len
*
size_per_head
+
seq_id
*
size_per_head
+
threadIdx
.
x
];
seq_id
*
size_per_head
;
const
int
dst_offset
=
batch_id
*
real_seq_len
*
head_num
*
size_per_head
+
seq_id
*
head_num
*
size_per_head
+
head_id
*
size_per_head
;
dst
[
threadIdx
.
x
+
dst_offset
]
=
src
[
threadIdx
.
x
+
src_offset
];
}
// Helper for TransposePadding / TransposeUnPadding below: reinterprets the
// half-precision input/output buffers as TYPE — a vector of VECTOR_SIZE half
// elements (int4 = 8 halfs, half2 = 2 halfs, half = scalar) — and launches
// transpose_qkv_<PAD_TYPE> with head_size / VECTOR_SIZE threads in block.x
// and head_num threads in block.y, so each thread moves one vector.
// NOTE(review): the macro expects `grid`, `input`, `output`, `batch`,
// `seq_len`, `head_num`, `head_size`, `real_seq_len` and `stream` to be in
// scope at the expansion site, and head_size must be divisible by
// VECTOR_SIZE (the callers guarantee this via their % checks).
#define LAUNCH_TRANSPOSE_KERNEL(TYPE, VECTOR_SIZE, PAD_TYPE) \
do { \
int h = head_size / VECTOR_SIZE; \
const TYPE *input##VECTOR_SIZE = reinterpret_cast<const TYPE *>(input); \
TYPE *output##VECTOR_SIZE = reinterpret_cast<TYPE *>(output); \
dim3 block(h, head_num, 1); \
transpose_qkv_##PAD_TYPE<TYPE> \
<<<grid, block, 0, stream>>>(input##VECTOR_SIZE, \
output##VECTOR_SIZE, \
batch, \
seq_len, \
head_num, \
h, \
real_seq_len); \
} while (0)
// Transposes half-precision QKV data into the padded layout consumed by the
// attention kernels (the call site notes BxSx3xNxH => 3xBxNxSxH), picking the
// widest vectorized kernel the head size allows.
//
// input        source buffer (half)
// output       destination buffer (half)
// batch        batch size B
// seq_len      padded sequence length S
// head_num     number of attention heads N
// head_size    per-head hidden size H
// real_seq_len actual (unpadded) sequence length
// stream       CUDA stream the kernel launch is enqueued on
inline void TransposePadding(const half *input,
                             half *output,
                             const int batch,
                             const int seq_len,
                             const int head_num,
                             const int head_size,
                             const int real_seq_len,
                             cudaStream_t stream) {
  // grid.z == 3: one slice per tensor — presumably Q, K and V, per the
  // transpose_qkv_* kernel naming. TODO confirm against the kernel body.
  const dim3 grid(seq_len, batch, 3);
  if (head_size % 8 == 0) {
    LAUNCH_TRANSPOSE_KERNEL(int4, 8, padding);  // 8 halfs (16 B) per thread
  } else if (head_size % 2 == 0) {
    LAUNCH_TRANSPOSE_KERNEL(half2, 2, padding);  // 2 halfs per thread
  } else {
    LAUNCH_TRANSPOSE_KERNEL(half, 1, padding);  // scalar fallback
  }
}
// Inverse of TransposePadding: copies attention output back out of the
// padded layout, again dispatching on head_size to the widest vector type
// available (int4 = 8 halfs, half2 = 2 halfs, scalar half).
//
// input        source buffer (half)
// output       destination buffer (half)
// batch        batch size B
// seq_len      padded sequence length S
// head_num     number of attention heads N
// head_size    per-head hidden size H
// real_seq_len actual (unpadded) sequence length
// stream       CUDA stream the kernel launch is enqueued on
inline void TransposeUnPadding(const half *input,
                               half *output,
                               const int batch,
                               const int seq_len,
                               const int head_num,
                               const int head_size,
                               const int real_seq_len,
                               cudaStream_t stream) {
  // Only real_seq_len rows are written back; the padding (no grid.z here,
  // unlike TransposePadding) is dropped.
  const dim3 grid(real_seq_len, batch);
  if (head_size % 8 == 0) {
    LAUNCH_TRANSPOSE_KERNEL(int4, 8, unpadding);  // 8 halfs (16 B) per thread
  } else if (head_size % 2 == 0) {
    LAUNCH_TRANSPOSE_KERNEL(half2, 2, unpadding);  // 2 halfs per thread
  } else {
    LAUNCH_TRANSPOSE_KERNEL(half, 1, unpadding);  // scalar fallback
  }
}
}
// TensorRT plugin lifecycle hook. This plugin allocates nothing up front,
// so there is nothing to do: return 0 (success per the IPluginV2 contract).
int QkvToContextPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
// TensorRT plugin lifecycle hook. This plugin allocates nothing up front,
// so there is nothing to do: return 0 (success per the IPluginV2 contract).
int QkvToContextPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
...
@@ -381,15 +434,14 @@ int QkvToContextPluginDynamic::enqueue(
...
@@ -381,15 +434,14 @@ int QkvToContextPluginDynamic::enqueue(
const
half
*
input1_data
=
static_cast
<
const
half
*>
(
qk_bias
);
const
half
*
input1_data
=
static_cast
<
const
half
*>
(
qk_bias
);
// BxSx3xNxH => tptr: 3xBxNxSxH.
// BxSx3xNxH => tptr: 3xBxNxSxH.
if
(
need_padding
)
{
if
(
need_padding
)
{
dim3
grid_p
(
seq_len
,
batch
,
3
);
TransposePadding
(
input0_data
,
dim3
block_p
(
head_size_
,
head_number_
,
1
);
transpose_qkv_padding
<<<
grid_p
,
block_p
,
0
,
stream
>>>
(
input0_data
,
tptr
,
tptr
,
batch
,
batch
,
seq_len
,
seq_len
,
head_number_
,
head_number_
,
head_size_
,
head_size_
,
real_seq_len
);
real_seq_len
,
stream
);
}
else
{
}
else
{
TransposeQKV
(
TransposeQKV
(
batch
,
seq_len
,
head_size_
,
head_number_
,
input0_data
,
tptr
,
stream
);
batch
,
seq_len
,
head_size_
,
head_number_
,
input0_data
,
tptr
,
stream
);
...
@@ -424,10 +476,14 @@ int QkvToContextPluginDynamic::enqueue(
...
@@ -424,10 +476,14 @@ int QkvToContextPluginDynamic::enqueue(
int
block
=
head_size_
;
int
block
=
head_size_
;
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
if
(
need_padding
)
{
if
(
need_padding
)
{
int
grid_u
=
batch
*
head_number_
*
real_seq_len
;
TransposeUnPadding
(
tptr
,
int
block_u
=
head_size_
;
output
,
transpose_qkv_unpadding
<
half
><<<
grid_u
,
block_u
,
0
,
stream
>>>
(
batch
,
tptr
,
output
,
batch
,
seq_len
,
head_number_
,
head_size_
,
real_seq_len
);
seq_len
,
head_number_
,
head_size_
,
real_seq_len
,
stream
);
}
else
{
}
else
{
transpose
<
half
><<<
grid
,
block
,
0
,
stream
>>>
(
transpose
<
half
><<<
grid
,
block
,
0
,
stream
>>>
(
tptr
,
output
,
batch
,
seq_len
,
head_number_
,
head_size_
);
tptr
,
output
,
batch
,
seq_len
,
head_number_
,
head_size_
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录