BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 11b9d85f (unverified)
Authored Nov 28, 2022 by Wang Bojun; committed via GitHub on Nov 28, 2022
fix: multihead matmul biasqk broadcast support for [1,1,seq,seq] shape (#47975)
* add trt support
Parent: 57e22f58

Showing 5 changed files with 606 additions and 6 deletions (+606 -6)
paddle/fluid/inference/tensorrt/op_teller.cc  (+11 -5)
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu  (+45 -0)
paddle/fluid/operators/fused/multihead_matmul_op.cu  (+29 -1)
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py  (+414 -0)
python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py  (+107 -0)
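
What the fix enables, in one line: a BiasQK tensor of shape [1, 1, seq_len, seq_len] is now accepted by the fused multihead_matmul op and its TensorRT plugin, and is broadcast across the batch and head dimensions before being added to the scaled Q*K attention scores. A minimal NumPy sketch of that broadcast semantics (illustrative only, not part of the commit):

    import numpy as np

    batch, head_num, seq_len = 2, 4, 8
    qk = np.random.rand(batch, head_num, seq_len, seq_len).astype("float32")
    bias_qk = np.random.rand(1, 1, seq_len, seq_len).astype("float32")

    # The [1, 1, seq, seq] bias is replicated over the batch and head dimensions;
    # this is what the new broadcast_batch_head_number kernel does on the GPU.
    expanded = np.tile(bias_qk, [batch, head_num, 1, 1])
    assert np.allclose(qk + bias_qk, qk + expanded)  # plain NumPy broadcasting agrees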
paddle/fluid/inference/tensorrt/op_teller.cc

@@ -1744,13 +1744,19 @@ struct SimpleOpTypeSetTeller : public Teller {
             input_shape[1] == biasqk_shape[3];
         bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 &&
                                 input_shape[1] == biasqk_shape[3];
+        is_broadcastable =
+            is_broadcastable || (biasqk_shape[0] == 1 && biasqk_shape[1] == 1 &&
+                                 input_shape[1] == biasqk_shape[2] &&
+                                 input_shape[1] == biasqk_shape[3]);
         if (!(has_same_shape || is_broadcastable)) {
           VLOG(3) << "The BiasQK's shape is invalid, expect ["
-                  << input_shape[0] << ", 1, 1, " << input_shape[1] << "] or ["
-                  << input_shape[0] << ", " << head_number << ", "
-                  << input_shape[1] << ", " << input_shape[1] << "] but ["
-                  << biasqk_shape[0] << ", " << biasqk_shape[1] << ", "
-                  << biasqk_shape[2] << ", " << biasqk_shape[3] << "].";
+                  << input_shape[0] << ", 1, 1, " << input_shape[1] << "] "
+                  << "or [" << input_shape[0] << ", " << head_number << ", "
+                  << input_shape[1] << ", " << input_shape[1] << "] "
+                  << "or [" << input_shape[0] << "/1, " << 1 << ", "
+                  << input_shape[1] << ", " << input_shape[1] << "] "
+                  << "but got [" << biasqk_shape[0] << ", " << biasqk_shape[1]
+                  << ", " << biasqk_shape[2] << ", " << biasqk_shape[3] << "].";
           return false;
         }
       } else {
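
In plain terms, the teller for multihead_matmul now accepts a BiasQK that either matches the attention-score shape exactly, is broadcastable as [batch, 1, 1, seq_len], or, new in this commit, is broadcastable as [1, 1, seq_len, seq_len]. A rough Python restatement of the updated condition (illustrative only; has_same_shape is computed earlier in op_teller.cc and is just passed through here, and input_shape[0] / input_shape[1] are the batch size and sequence length as used in the error message above):

    def biasqk_accepted(input_shape, biasqk_shape, has_same_shape):
        seq_len = input_shape[1]
        # Existing rule: a [batch, 1, 1, seq_len] style bias.
        is_broadcastable = (
            biasqk_shape[1] == 1 and biasqk_shape[2] == 1 and biasqk_shape[3] == seq_len
        )
        # Rule added by this commit: a [1, 1, seq_len, seq_len] bias.
        is_broadcastable = is_broadcastable or (
            biasqk_shape[0] == 1
            and biasqk_shape[1] == 1
            and biasqk_shape[2] == seq_len
            and biasqk_shape[3] == seq_len
        )
        return has_same_shape or is_broadcastable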
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu

@@ -309,6 +309,19 @@ __global__ void broadcast(const T *src,
     }
   }
 }
+
+template <typename T>
+__global__ void broadcast_batch_head_number(const T *src,
+                                            T *dst,
+                                            const int batch_size,
+                                            const int seq_len,
+                                            const int head_num) {
+  int batch_id = blockIdx.x % seq_len;
+  int dst_offset = blockIdx.x * seq_len;
+  if (threadIdx.x < seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
+  }
+}
 
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc,

@@ -353,6 +366,22 @@ int QkvToContextPluginDynamic::enqueue(
                                                           head_number_);
       qk_bias = temp_qk_bias;
     }
+    // fit to [batch, head_num, length, length] + [1, 1, length, length]
+    if (ProductDim(input_desc[1].dims) == (seq_len * seq_len)) {
+      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
+      auto *temp_qk_bias =
+          reinterpret_cast<float *>(temp_qk_bias_tensor.mutable_data<float>(
+              platform::CUDAPlace(device_id)));
+      int grid = batch * head_number_ * seq_len;
+      int block = round_up(seq_len);
+      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
+          static_cast<const float *>(inputs[1]),
+          temp_qk_bias,
+          batch,
+          seq_len,
+          head_number_);
+      qk_bias = temp_qk_bias;
+    }
     // fake qk_bias
     if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
       qk_bias = fake_qk_bias_;

@@ -424,6 +453,22 @@ int QkvToContextPluginDynamic::enqueue(
                                                           head_number_);
       qk_bias = temp_qk_bias;
     }
+    // fit to [batch, head_num, length, length] + [1, 1, length, length]
+    if (ProductDim(input_desc[1].dims) == (seq_len * seq_len)) {
+      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
+      auto *temp_qk_bias =
+          reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
+              platform::CUDAPlace(device_id)));
+      int grid = batch * head_number_ * seq_len;
+      int block = round_up(seq_len);
+      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
+          static_cast<const half *>(inputs[1]),
+          temp_qk_bias,
+          batch,
+          seq_len,
+          head_number_);
+      qk_bias = temp_qk_bias;
+    }
     // padding: mask_half_ = [1.0,....1.0...1.0....,0.0f]
     // no_padding: mask_half_ = [1.0,....1.0,.........,1.0f]
     bool bias_is_mask = false;
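
The new broadcast_batch_head_number kernel launches one block per row of the expanded bias: blockIdx.x runs over batch * head_num * seq_len rows, and blockIdx.x % seq_len picks the source row inside the single [seq_len, seq_len] bias matrix, so that matrix is replicated across every batch and head. A small NumPy reference of that index math (illustrative only, not part of the commit):

    import numpy as np

    def broadcast_batch_head_number_ref(src, batch, seq_len, head_num):
        # CPU reference for the CUDA kernel: src is the flat [seq_len * seq_len] bias.
        dst = np.empty(batch * head_num * seq_len * seq_len, dtype=src.dtype)
        for block_id in range(batch * head_num * seq_len):  # one CUDA block per output row
            src_row = block_id % seq_len                    # blockIdx.x % seq_len
            dst[block_id * seq_len:(block_id + 1) * seq_len] = (
                src[src_row * seq_len:(src_row + 1) * seq_len]
            )
        return dst

    batch, head_num, seq_len = 2, 3, 4
    bias = np.random.rand(seq_len * seq_len).astype("float32")
    out = broadcast_batch_head_number_ref(bias, batch, seq_len, head_num)
    expected = np.tile(bias.reshape(1, 1, seq_len, seq_len), [batch, head_num, 1, 1]).ravel()
    assert np.array_equal(out, expected)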
paddle/fluid/operators/fused/multihead_matmul_op.cu

@@ -256,6 +256,19 @@ __global__ void broadcast(const T *src,
     }
   }
 }
+
+template <typename T>
+__global__ void broadcast_batch_head_number(const T *src,
+                                            T *dst,
+                                            const int batch_size,
+                                            const int seq_len,
+                                            const int head_num) {
+  int src_seq_id = blockIdx.x % seq_len;
+  int dst_offset = blockIdx.x * seq_len;
+  if (threadIdx.x < seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_seq_id * seq_len];
+  }
+}
 
 template <typename DeviceContext, typename T>
 class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
  public:

@@ -286,6 +299,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
     Tensor temp_bias_tensor;
     // if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
     if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
+      VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
       temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
       auto *temp_qk_bias = device_ctx.template Alloc<T>(
           &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));

@@ -295,6 +309,19 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
           bias_qk_d, temp_qk_bias, seq_len, head_number);
       bias_qk_d = static_cast<const T *>(temp_qk_bias);
     }
+    // if bias_qk is[1, 1, seq_len, seq_len], the bias_qk_d need to be
+    // broadcasted
+    if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
+      VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
+      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
+      auto *temp_qk_bias = device_ctx.template Alloc<T>(
+          &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
+      int grid = batch * head_number * seq_len;
+      int block = round_up(seq_len);
+      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
+          bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
+      bias_qk_d = static_cast<const T *>(temp_qk_bias);
+    }
     if (!bias_qk) {
       int size = batch * head_number * seq_len * seq_len;
       temp_bias_tensor.Resize({size});

@@ -333,7 +360,8 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
     // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
     auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(device_ctx);
     blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
-
+    VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
+    VLOG(2) << temp_out_tensor;
     // temp_out_tensor.Resize(temp_out_dims);
 
     Tensor multihead_temp_tensor;
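
On the native (non-TensorRT) path, BiasQK handling in MultiHeadMatMulV2Kernel therefore selects its broadcast behaviour purely from the bias element count. A simplified restatement of those checks (illustrative; the real kernel uses independent if blocks and also falls through to using BiasQK as-is when its shape already matches):

    def biasqk_branch(biasqk_numel, batch, seq_len):
        # Mirrors the numel() checks in MultiHeadMatMulV2Kernel above.
        if biasqk_numel is None:                # no BiasQK input at all
            return "allocate a placeholder bias of size batch*head_number*seq_len*seq_len"
        if biasqk_numel == batch * seq_len:     # [batch, 1, 1, seq_len]
            return "broadcast over head and query dimensions"
        if biasqk_numel == seq_len * seq_len:   # [1, 1, seq_len, seq_len] (new)
            return "broadcast over batch and head dimensions"
        return "use BiasQK as-is"

    print(biasqk_branch(8 * 128, batch=8, seq_len=128))    # existing path
    print(biasqk_branch(128 * 128, batch=8, seq_len=128))  # path added by this commit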
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py

(This diff is collapsed in the web view and not shown here.)
python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py

@@ -29,6 +29,113 @@ def stable_softmax(x):
     return exps / np.sum(exps)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA"
+)
+class TestFusedMultiHeadMatmulOp_biasqk2(OpTest):
+    def config(self):
+        self.seq_len = 128
+        self.size_per_head = 64
+        self.head_number = 12
+        self.batch_size = 8
+        self.scale = 0.125
+
+    def setUp(self):
+        self.op_type = "multihead_matmul"
+        self.config()
+        h = self.seq_len
+        w = self.head_number * self.size_per_head
+        self.Input = (
+            np.random.random((self.batch_size, h, w)).astype("float32") - 0.5
+        )
+        self.WQ = np.random.random((w, w)).astype("float32")
+        self.KQ = np.random.random((w, w)).astype("float32")
+        self.VQ = np.random.random((w, w)).astype("float32")
+        self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
+            (w, 3, w)
+        )
+        self.Q = np.dot(self.Input, self.WQ)
+        self.K = np.dot(self.Input, self.KQ)
+        self.V = np.dot(self.Input, self.VQ)
+
+        self.BiasQ = np.random.random((1, w)).astype("float32")
+        self.BiasK = np.random.random((1, w)).astype("float32")
+        self.BiasV = np.random.random((1, w)).astype("float32")
+        self.CombinedB = np.vstack((self.BiasQ, self.BiasK, self.BiasV))
+        self.BiasQK = np.random.random(
+            (1, 1, self.seq_len, self.seq_len)
+        ).astype("float32")
+        # Compute Q path
+        fc_q = self.Q + self.BiasQ
+        reshape_q = np.reshape(
+            fc_q,
+            (
+                self.batch_size,
+                self.seq_len,
+                self.head_number,
+                self.size_per_head,
+            ),
+        )
+        transpose_q = np.transpose(reshape_q, (0, 2, 1, 3))
+        scale_q = self.scale * transpose_q
+        # Compute K path
+        fc_k = self.K + self.BiasK
+        reshape_k = np.reshape(
+            fc_k,
+            (
+                self.batch_size,
+                self.seq_len,
+                self.head_number,
+                self.size_per_head,
+            ),
+        )
+        transpose_k = np.transpose(reshape_k, (0, 2, 3, 1))
+        # Compute Q*K
+        q_k = np.matmul(scale_q, transpose_k)
+        eltadd_qk = q_k + np.tile(
+            self.BiasQK, [self.batch_size, self.head_number, 1, 1]
+        )
+        softmax_qk = np.apply_along_axis(stable_softmax, 3, eltadd_qk)
+        # Compute V path
+        fc_v = self.V + self.BiasV
+        reshape_v = np.reshape(
+            fc_v,
+            (
+                self.batch_size,
+                self.seq_len,
+                self.head_number,
+                self.size_per_head,
+            ),
+        )
+        transpose_v = np.transpose(reshape_v, (0, 2, 1, 3))
+        # Compute QK*V
+        qkv = np.matmul(softmax_qk, transpose_v)
+        transpose_qkv = np.transpose(qkv, (0, 2, 1, 3))
+        reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))
+        print("biasqk shape")
+        print(self.BiasQK.shape)
+        self.inputs = {
+            "Input": self.Input,
+            "W": self.CombinedW,
+            "Bias": self.CombinedB,
+            "BiasQK": self.BiasQK,
+        }
+        self.attrs = {
+            "transpose_Q": False,
+            "transpose_K": True,
+            "transpose_V": False,
+            "head_number": self.head_number,
+            "alpha": self.scale,
+        }
+        self.outputs = {"Out": reshape_qkv}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, atol=2e-3)
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA"
 )
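
The new TestFusedMultiHeadMatmulOp_biasqk2 case builds a NumPy reference in which the [1, 1, seq_len, seq_len] BiasQK is tiled over batch and head before the softmax, then checks the fused multihead_matmul op against it with atol=2e-3 on a CUDA device. A sketch of running just this case outside the CI harness (assumes a CUDA build of Paddle and that the test file's directory is on the import path):

    import unittest

    # Illustrative standalone runner; in CI the file is executed as a normal
    # unittest module by Paddle's test harness.
    from test_fused_multihead_matmul_op import TestFusedMultiHeadMatmulOp_biasqk2

    suite = unittest.TestLoader().loadTestsFromTestCase(TestFusedMultiHeadMatmulOp_biasqk2)
    unittest.TextTestRunner(verbosity=2).run(suite)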