Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
cf9eae4c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cf9eae4c
编写于
9月 17, 2021
作者:
F
feng_shuai
提交者:
GitHub
9月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
broadcast qkv_op (#35780)
* broadcast qkv_op * use PADDLE_ENFORCE_GT to replace assert
上级
7975dfcf
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
79 addition
and
4 deletion
+79
-4
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
.../fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+49
-2
paddle/fluid/operators/fused/multihead_matmul_op.cu
paddle/fluid/operators/fused/multihead_matmul_op.cu
+30
-2
未找到文件。
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
浏览文件 @
cf9eae4c
...
@@ -233,6 +233,24 @@ __global__ void apply_scale(T *data, T scale, int n) {
...
@@ -233,6 +233,24 @@ __global__ void apply_scale(T *data, T scale, int n) {
#endif
#endif
}
}
inline
int
round_up
(
int
seq_len
,
int
multiple
=
32
)
{
PADDLE_ENFORCE_GT
(
multiple
,
0
,
platform
::
errors
::
InvalidArgument
(
"multiple should be a positive number,but it's (%d)"
,
multiple
));
return
((
seq_len
+
multiple
-
1
)
/
multiple
)
*
multiple
;
}
template
<
typename
T
>
__global__
void
broadcast
(
const
T
*
src
,
T
*
dst
,
const
int
seq_len
,
const
int
head_num
)
{
int
batch_id
=
blockIdx
.
x
/
(
head_num
*
seq_len
);
int
dst_offset
=
blockIdx
.
x
*
seq_len
;
if
(
threadIdx
.
x
<
seq_len
)
{
dst
[
threadIdx
.
x
+
dst_offset
]
=
src
[
threadIdx
.
x
+
batch_id
*
seq_len
];
}
}
int
QkvToContextPluginDynamic
::
enqueue
(
int
QkvToContextPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
...
@@ -258,7 +276,21 @@ int QkvToContextPluginDynamic::enqueue(
...
@@ -258,7 +276,21 @@ int QkvToContextPluginDynamic::enqueue(
auto
*
tptr
=
multihead_temp_data
+
scratch_size
;
auto
*
tptr
=
multihead_temp_data
+
scratch_size
;
const
float
*
input0_data
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
const
float
*
input0_data
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
const
float
*
input1_data
=
static_cast
<
const
float
*>
(
inputs
[
1
]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
framework
::
Tensor
temp_qk_bias_tensor
;
float
*
qk_bias
=
const_cast
<
float
*>
(
static_cast
<
const
float
*>
(
inputs
[
1
]));
if
(
ProductDim
(
input_desc
[
1
].
dims
)
==
(
batch
*
seq_len
))
{
temp_qk_bias_tensor
.
Resize
({
batch
,
head_number_
,
seq_len
,
seq_len
});
auto
*
temp_qk_bias
=
temp_qk_bias_tensor
.
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
device_id
));
int
grid
=
batch
*
head_number_
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
const
float
*>
(
inputs
[
1
]),
temp_qk_bias
,
seq_len
,
head_number_
);
qk_bias
=
temp_qk_bias
;
}
const
float
*
input1_data
=
static_cast
<
const
float
*>
(
qk_bias
);
// BxSx3xNxH => tptr: 3xBxNxSxH.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV
(
batch
,
seq_len
,
head_size_
,
head_number_
,
input0_data
,
tptr
,
TransposeQKV
(
batch
,
seq_len
,
head_size_
,
head_number_
,
input0_data
,
tptr
,
stream
);
stream
);
...
@@ -290,7 +322,22 @@ int QkvToContextPluginDynamic::enqueue(
...
@@ -290,7 +322,22 @@ int QkvToContextPluginDynamic::enqueue(
half
*
tptr
=
qkptr
+
scratch_size
;
half
*
tptr
=
qkptr
+
scratch_size
;
const
half
*
input0_data
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
half
*
input0_data
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
half
*
input1_data
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
framework
::
Tensor
temp_qk_bias_tensor
;
half
*
qk_bias
=
const_cast
<
half
*>
(
static_cast
<
const
half
*>
(
inputs
[
1
]));
if
(
ProductDim
(
input_desc
[
1
].
dims
)
==
(
batch
*
seq_len
))
{
temp_qk_bias_tensor
.
Resize
({
batch
,
head_number_
,
seq_len
,
seq_len
});
auto
*
temp_qk_bias
=
reinterpret_cast
<
half
*>
(
temp_qk_bias_tensor
.
mutable_data
<
int16_t
>
(
platform
::
CUDAPlace
(
device_id
)));
int
grid
=
batch
*
head_number_
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
const
half
*>
(
inputs
[
1
]),
temp_qk_bias
,
seq_len
,
head_number_
);
qk_bias
=
temp_qk_bias
;
}
const
half
*
input1_data
=
static_cast
<
const
half
*>
(
qk_bias
);
// BxSx3xNxH => tptr: 3xBxNxSxH.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV
(
batch
,
seq_len
,
head_size_
,
head_number_
,
input0_data
,
tptr
,
TransposeQKV
(
batch
,
seq_len
,
head_size_
,
head_number_
,
input0_data
,
tptr
,
stream
);
stream
);
...
...
paddle/fluid/operators/fused/multihead_matmul_op.cu
浏览文件 @
cf9eae4c
...
@@ -132,6 +132,24 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
...
@@ -132,6 +132,24 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
}
}
}
}
inline
int
round_up
(
int
seq_len
,
int
multiple
=
32
)
{
PADDLE_ENFORCE_GT
(
multiple
,
0
,
platform
::
errors
::
InvalidArgument
(
"multiple should be a positive number,but it's (%d)"
,
multiple
));
return
((
seq_len
+
multiple
-
1
)
/
multiple
)
*
multiple
;
}
template
<
typename
T
>
__global__
void
broadcast
(
const
T
*
src
,
T
*
dst
,
const
int
seq_len
,
const
int
head_num
)
{
int
batch_id
=
blockIdx
.
x
/
(
head_num
*
seq_len
);
int
dst_offset
=
blockIdx
.
x
*
seq_len
;
if
(
threadIdx
.
x
<
seq_len
)
{
dst
[
threadIdx
.
x
+
dst_offset
]
=
src
[
threadIdx
.
x
+
batch_id
*
seq_len
];
}
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MultiHeadMatMulV2Kernel
:
public
framework
::
OpKernel
<
T
>
{
class
MultiHeadMatMulV2Kernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -152,6 +170,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
...
@@ -152,6 +170,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
int
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
// compute q*k with eltadd
// compute q*k with eltadd
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
stream
=
device_ctx
.
stream
();
// should be (B * S * hidden)
// should be (B * S * hidden)
auto
input_dims
=
input
->
dims
();
auto
input_dims
=
input
->
dims
();
// shouble be (hidden * 3 * all_head_size)
// shouble be (hidden * 3 * all_head_size)
...
@@ -159,7 +178,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
...
@@ -159,7 +178,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int
batch
=
input_dims
[
0
];
int
batch
=
input_dims
[
0
];
int
seq_len
=
input_dims
[
1
];
int
seq_len
=
input_dims
[
1
];
int
hidden
=
input_dims
[
2
];
int
hidden
=
input_dims
[
2
];
Tensor
temp_bias_tensor
;
// if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
if
(
bias_qk
.
numel
()
==
(
batch
*
seq_len
))
{
temp_bias_tensor
.
Resize
({
batch
*
head_number
*
seq_len
*
seq_len
});
auto
*
temp_qk_bias
=
temp_bias_tensor
.
mutable_data
<
T
>
(
context
.
GetPlace
());
int
grid
=
batch
*
head_number
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast
<<<
grid
,
block
,
0
,
stream
>>>
(
bias_qk_d
,
temp_qk_bias
,
seq_len
,
head_number
);
bias_qk_d
=
static_cast
<
const
T
*>
(
temp_qk_bias
);
}
int
all_head_size
=
w_dims
[
2
];
int
all_head_size
=
w_dims
[
2
];
int
head_size
=
all_head_size
/
head_number
;
int
head_size
=
all_head_size
/
head_number
;
...
@@ -196,7 +225,6 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
...
@@ -196,7 +225,6 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto
*
qkptr
=
multihead_temp_data
;
auto
*
qkptr
=
multihead_temp_data
;
auto
*
tptr
=
multihead_temp_data
+
scratch_size
;
auto
*
tptr
=
multihead_temp_data
+
scratch_size
;
auto
stream
=
device_ctx
.
stream
();
// Do the transpose with bias.
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias
(
batch
,
seq_len
,
head_size
,
head_number
,
temp_out_data
,
TransQKVWithBias
(
batch
,
seq_len
,
head_size
,
head_number
,
temp_out_data
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录