Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
cf9eae4c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cf9eae4c
编写于
9月 17, 2021
作者:
F
feng_shuai
提交者:
GitHub
9月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
broadcast qkv_op (#35780)
* broadcast qkv_op * use PADDLE_ENFORCE_GT to replace assert
上级
7975dfcf
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
79 addition
and
4 deletion
+79
-4
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
.../fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+49
-2
paddle/fluid/operators/fused/multihead_matmul_op.cu
paddle/fluid/operators/fused/multihead_matmul_op.cu
+30
-2
未找到文件。
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
浏览文件 @
cf9eae4c
...
...
@@ -233,6 +233,24 @@ __global__ void apply_scale(T *data, T scale, int n) {
#endif
}
inline
int
round_up
(
int
seq_len
,
int
multiple
=
32
)
{
PADDLE_ENFORCE_GT
(
multiple
,
0
,
platform
::
errors
::
InvalidArgument
(
"multiple should be a positive number,but it's (%d)"
,
multiple
));
return
((
seq_len
+
multiple
-
1
)
/
multiple
)
*
multiple
;
}
template
<
typename
T
>
__global__
void
broadcast
(
const
T
*
src
,
T
*
dst
,
const
int
seq_len
,
const
int
head_num
)
{
int
batch_id
=
blockIdx
.
x
/
(
head_num
*
seq_len
);
int
dst_offset
=
blockIdx
.
x
*
seq_len
;
if
(
threadIdx
.
x
<
seq_len
)
{
dst
[
threadIdx
.
x
+
dst_offset
]
=
src
[
threadIdx
.
x
+
batch_id
*
seq_len
];
}
}
int
QkvToContextPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
...
...
@@ -258,7 +276,21 @@ int QkvToContextPluginDynamic::enqueue(
auto
*
tptr
=
multihead_temp_data
+
scratch_size
;
const
float
*
input0_data
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
const
float
*
input1_data
=
static_cast
<
const
float
*>
(
inputs
[
1
]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
framework
::
Tensor
temp_qk_bias_tensor
;
float
*
qk_bias
=
const_cast
<
float
*>
(
static_cast
<
const
float
*>
(
inputs
[
1
]));
if
(
ProductDim
(
input_desc
[
1
].
dims
)
==
(
batch
*
seq_len
))
{
temp_qk_bias_tensor
.
Resize
({
batch
,
head_number_
,
seq_len
,
seq_len
});
auto
*
temp_qk_bias
=
temp_qk_bias_tensor
.
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
device_id
));
int
grid
=
batch
*
head_number_
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
const
float
*>
(
inputs
[
1
]),
temp_qk_bias
,
seq_len
,
head_number_
);
qk_bias
=
temp_qk_bias
;
}
const
float
*
input1_data
=
static_cast
<
const
float
*>
(
qk_bias
);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV
(
batch
,
seq_len
,
head_size_
,
head_number_
,
input0_data
,
tptr
,
stream
);
...
...
@@ -290,7 +322,22 @@ int QkvToContextPluginDynamic::enqueue(
half
*
tptr
=
qkptr
+
scratch_size
;
const
half
*
input0_data
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
half
*
input1_data
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
framework
::
Tensor
temp_qk_bias_tensor
;
half
*
qk_bias
=
const_cast
<
half
*>
(
static_cast
<
const
half
*>
(
inputs
[
1
]));
if
(
ProductDim
(
input_desc
[
1
].
dims
)
==
(
batch
*
seq_len
))
{
temp_qk_bias_tensor
.
Resize
({
batch
,
head_number_
,
seq_len
,
seq_len
});
auto
*
temp_qk_bias
=
reinterpret_cast
<
half
*>
(
temp_qk_bias_tensor
.
mutable_data
<
int16_t
>
(
platform
::
CUDAPlace
(
device_id
)));
int
grid
=
batch
*
head_number_
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
const
half
*>
(
inputs
[
1
]),
temp_qk_bias
,
seq_len
,
head_number_
);
qk_bias
=
temp_qk_bias
;
}
const
half
*
input1_data
=
static_cast
<
const
half
*>
(
qk_bias
);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV
(
batch
,
seq_len
,
head_size_
,
head_number_
,
input0_data
,
tptr
,
stream
);
...
...
paddle/fluid/operators/fused/multihead_matmul_op.cu
浏览文件 @
cf9eae4c
...
...
@@ -132,6 +132,24 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
}
}
inline
int
round_up
(
int
seq_len
,
int
multiple
=
32
)
{
PADDLE_ENFORCE_GT
(
multiple
,
0
,
platform
::
errors
::
InvalidArgument
(
"multiple should be a positive number,but it's (%d)"
,
multiple
));
return
((
seq_len
+
multiple
-
1
)
/
multiple
)
*
multiple
;
}
template
<
typename
T
>
__global__
void
broadcast
(
const
T
*
src
,
T
*
dst
,
const
int
seq_len
,
const
int
head_num
)
{
int
batch_id
=
blockIdx
.
x
/
(
head_num
*
seq_len
);
int
dst_offset
=
blockIdx
.
x
*
seq_len
;
if
(
threadIdx
.
x
<
seq_len
)
{
dst
[
threadIdx
.
x
+
dst_offset
]
=
src
[
threadIdx
.
x
+
batch_id
*
seq_len
];
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
MultiHeadMatMulV2Kernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -152,6 +170,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
// compute q*k with eltadd
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
stream
=
device_ctx
.
stream
();
// should be (B * S * hidden)
auto
input_dims
=
input
->
dims
();
// shouble be (hidden * 3 * all_head_size)
...
...
@@ -159,7 +178,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int
batch
=
input_dims
[
0
];
int
seq_len
=
input_dims
[
1
];
int
hidden
=
input_dims
[
2
];
Tensor
temp_bias_tensor
;
// if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
if
(
bias_qk
.
numel
()
==
(
batch
*
seq_len
))
{
temp_bias_tensor
.
Resize
({
batch
*
head_number
*
seq_len
*
seq_len
});
auto
*
temp_qk_bias
=
temp_bias_tensor
.
mutable_data
<
T
>
(
context
.
GetPlace
());
int
grid
=
batch
*
head_number
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast
<<<
grid
,
block
,
0
,
stream
>>>
(
bias_qk_d
,
temp_qk_bias
,
seq_len
,
head_number
);
bias_qk_d
=
static_cast
<
const
T
*>
(
temp_qk_bias
);
}
int
all_head_size
=
w_dims
[
2
];
int
head_size
=
all_head_size
/
head_number
;
...
...
@@ -196,7 +225,6 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto
*
qkptr
=
multihead_temp_data
;
auto
*
tptr
=
multihead_temp_data
+
scratch_size
;
auto
stream
=
device_ctx
.
stream
();
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias
(
batch
,
seq_len
,
head_size
,
head_number
,
temp_out_data
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录