Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
472dcca4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
472dcca4
编写于
11月 08, 2021
作者:
L
Li Min
提交者:
GitHub
11月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【fix-bug】Support attn_mask=None input cases for fused_attention_op. (#36951)
目前的fused_attention_op不支持attn_mask=None的输入,本PR对此进行了补充,并补充了相应的单测逻辑。
上级
b7e88308
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
178 addition
and
74 deletion
+178
-74
paddle/fluid/operators/fused/fmha_ref.h
paddle/fluid/operators/fused/fmha_ref.h
+11
-11
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+25
-14
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+11
-5
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
...n/paddle/fluid/tests/unittests/test_fused_attention_op.py
+54
-11
python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
...ddle/fluid/tests/unittests/test_fused_attention_op_api.py
+77
-33
未找到文件。
paddle/fluid/operators/fused/fmha_ref.h
浏览文件 @
472dcca4
...
...
@@ -69,7 +69,7 @@ class FMHARef {
~
FMHARef
()
{}
void
ComputeForward
(
const
Tensor
&
qkv_input_tensor
,
const
Tensor
&
src_mask_tensor
,
const
Tensor
*
src_mask_tensor
,
Tensor
*
transpose_2_out_tensor
,
Tensor
*
qk_out_tensor
,
Tensor
*
src_mask_out_tensor
,
Tensor
*
softmax_out_tensor
,
Tensor
*
dropout_mask_out_tensor
,
...
...
@@ -111,17 +111,17 @@ class FMHARef {
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
q_ptr
,
k_ptr
,
beta
,
qk_out_data
,
gemm_batch_size
,
stride_a
,
stride_b
);
int
softmax_axis
=
-
1
;
if
(
src_mask_tensor
!=
nullptr
)
{
std
::
vector
<
const
Tensor
*>
ins
;
std
::
vector
<
Tensor
*>
outs
;
ins
.
emplace_back
(
qk_out_tensor
);
ins
.
emplace_back
(
&
src_mask_tensor
);
ins
.
emplace_back
(
src_mask_tensor
);
outs
.
emplace_back
(
src_mask_out_tensor
);
int
elewise_add_axis
=
-
1
;
int
softmax_axis
=
-
1
;
if
(
&
src_mask_tensor
!=
nullptr
)
{
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx_
,
ins
,
&
outs
,
elewise_add_axis
,
AddFunctor
<
T
>
());
SoftmaxForwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
*
src_mask_out_tensor
,
softmax_axis
,
softmax_out_tensor
);
}
else
{
...
...
@@ -165,7 +165,7 @@ class FMHARef {
}
void
ComputeBackward
(
const
Tensor
&
transpose_2_out_tensor
,
const
Tensor
&
src_mask_tensor
,
const
Tensor
&
transpose_2_out_tensor
,
const
Tensor
*
src_mask_tensor
,
const
Tensor
&
softmax_out_tensor
,
const
Tensor
&
dropout_mask_out_tensor
,
const
Tensor
&
dropout_out_tensor
,
const
Tensor
&
qk_out_tensor
,
const
Tensor
&
src_mask_out_tensor
,
const
Tensor
&
fmha_out_grad_tensor
,
...
...
@@ -249,7 +249,7 @@ class FMHARef {
softmax_out_grad_tensor
);
}
if
(
&
src_mask_tensor
!=
nullptr
)
{
if
(
src_mask_tensor
!=
nullptr
)
{
SoftmaxBackwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
softmax_out_tensor
,
*
softmax_out_grad_tensor
,
softmax_axis
,
src_mask_out_grad_tensor
);
...
...
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
472dcca4
...
...
@@ -27,8 +27,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SrcMask"
),
"Input"
,
"SrcMask"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVW"
),
"Input"
,
"QKVW"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVBias"
),
"Input"
,
"QKVBias"
,
"FusedAttentionOp"
);
...
...
@@ -57,8 +55,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKTVOut"
),
"Output"
,
"QKTVOut"
,
"FusedAttentionOp"
);
if
(
ctx
->
HasInput
(
"SrcMask"
))
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SrcMaskOut"
),
"Output"
,
"SrcMaskOut"
,
"FusedAttentionOp"
);
}
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SoftmaxOut"
),
"Output"
,
"SoftmaxOut"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"AttnDropoutMaskOut"
),
"Output"
,
...
...
@@ -119,7 +120,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
{
y_dim
[
0
],
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
// [batch, num_head, seq_len, seq_len]
ctx
->
SetOutputDim
(
"QKOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
if
(
ctx
->
HasInput
(
"SrcMask"
))
{
ctx
->
SetOutputDim
(
"SrcMaskOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
}
// the same as QKOut's shape.
ctx
->
SetOutputDim
(
"AttnDropoutOut"
,
{
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
x_dim
[
1
]});
...
...
@@ -320,7 +324,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
{
out = transpose(out, perm=[2, 0, 3, 1, 4]);
out = q * k^t;
out = attn_ma
r
k + out;
out = attn_ma
s
k + out;
out = softmax(out);
out = dropout(out);
out = out * v;
...
...
@@ -368,8 +372,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVBias"
),
"Input"
,
"QKVBias"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SrcMask"
),
"Input"
,
"SrcMask"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearW"
),
"Input"
,
"OutLinearW"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
...
...
@@ -413,8 +415,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"SoftmaxOut"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"AttnDropoutOut"
),
ctx
->
GetInputDim
(
"AttnDropoutOut"
));
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"SrcMaskOut"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"SrcMaskOut"
),
ctx
->
GetInputDim
(
"SrcMaskOut"
));
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVOut"
),
ctx
->
GetInputDim
(
"QKVOut"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBiasOut"
),
...
...
@@ -448,7 +453,14 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"QKVW"
,
this
->
Input
(
"QKVW"
));
op
->
SetInput
(
"QKVBias"
,
this
->
Input
(
"QKVBias"
));
if
(
this
->
HasInput
(
"SrcMask"
))
{
op
->
SetInput
(
"SrcMask"
,
this
->
Input
(
"SrcMask"
));
op
->
SetInput
(
"SrcMaskOut"
,
this
->
Output
(
"SrcMaskOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"SrcMaskOut"
),
this
->
OutputGrad
(
"SrcMaskOut"
));
}
op
->
SetInput
(
"OutLinearW"
,
this
->
Input
(
"OutLinearW"
));
op
->
SetInput
(
"OutLinearBias"
,
this
->
Input
(
"OutLinearBias"
));
...
...
@@ -508,7 +520,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"SoftmaxOut"
,
this
->
Output
(
"SoftmaxOut"
));
op
->
SetInput
(
"AttnDropoutMaskOut"
,
this
->
Output
(
"AttnDropoutMaskOut"
));
op
->
SetInput
(
"AttnDropoutOut"
,
this
->
Output
(
"AttnDropoutOut"
));
op
->
SetInput
(
"SrcMaskOut"
,
this
->
Output
(
"SrcMaskOut"
));
op
->
SetInput
(
"FMHAOut"
,
this
->
Output
(
"FMHAOut"
));
op
->
SetInput
(
"OutLinearOut"
,
this
->
Output
(
"OutLinearOut"
));
...
...
@@ -538,8 +550,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this
->
OutputGrad
(
"SoftmaxOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"AttnDropoutOut"
),
this
->
OutputGrad
(
"AttnDropoutOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"SrcMaskOut"
),
this
->
OutputGrad
(
"SrcMaskOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"FMHAOut"
),
this
->
OutputGrad
(
"FMHAOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
),
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
472dcca4
...
...
@@ -114,7 +114,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
transpose_out_2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qk_out_data
=
qk_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qktv_out_data
=
qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
src_mask_out_data
=
src_mask_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
src_mask_out_data
=
(
src_mask
==
nullptr
)
?
nullptr
:
src_mask_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
softmax_out_data
=
softmax_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
attn_dropout_mask_out_data
=
attn_dropout_mask_out
->
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
...
...
@@ -184,10 +186,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_compute
.
ComputeForward
(
qkv_weight_data
,
x_data
,
qkv_bias_data
,
qkv_out_data
,
qkv_bias_out_data
);
}
fmha_ref_compute
.
ComputeForward
(
*
qkv_bias_out
,
*
src_mask
,
transpose_out_2
,
fmha_ref_compute
.
ComputeForward
(
*
qkv_bias_out
,
src_mask
,
transpose_out_2
,
qk_out
,
src_mask_out
,
softmax_out
,
attn_dropout_mask_out
,
attn_dropout_out
,
qktv_out
,
fmha_out
);
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
...
...
@@ -265,7 +268,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
qk_out_data
=
qk_out
->
data
<
T
>
();
auto
*
qktv_out_data
=
qktv_out
->
data
<
T
>
();
auto
*
softmax_out_data
=
softmax_out
->
data
<
T
>
();
auto
*
src_mask_out_data
=
src_mask_out
->
data
<
T
>
();
auto
*
src_mask_out_data
=
(
src_mask
==
nullptr
)
?
nullptr
:
src_mask_out
->
data
<
T
>
();
auto
*
out_linear_out_data
=
out_linear_out
->
data
<
T
>
();
auto
*
ln_2_mean_data
=
ln_2_mean
->
data
<
U
>
();
auto
*
ln_2_var_data
=
ln_2_var
->
data
<
U
>
();
...
...
@@ -302,7 +306,9 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
d_softmax_out_data
=
d_softmax_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_attn_dropout_out_data
=
d_attn_dropout_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_src_mask_out_data
=
d_src_mask_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_src_mask_out_data
=
(
src_mask
==
nullptr
)
?
nullptr
:
d_src_mask_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_fmha_out_data
=
d_fmha_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_out_linear_out_data
=
d_out_linear_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -386,7 +392,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out_data
,
d_fmha_out_data
,
d_out_linear_weight_data
,
nullptr
);
fmha_ref_compute
.
ComputeBackward
(
*
transpose_out_2
,
*
src_mask
,
*
softmax_out
,
*
attn_dropout_mask_out
,
*
transpose_out_2
,
src_mask
,
*
softmax_out
,
*
attn_dropout_mask_out
,
*
attn_dropout_out
,
*
qk_out
,
*
src_mask_out
,
*
d_fmha_out
,
d_qktv_out
,
d_attn_dropout_out
,
d_softmax_out
,
d_src_mask_out
,
d_qk_out
,
d_transpose_out_2
,
nullptr
,
d_qkv_bias_out
);
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
浏览文件 @
472dcca4
...
...
@@ -66,6 +66,7 @@ class TestFusedAttentionOp(OpTest):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
...
...
@@ -84,6 +85,7 @@ class TestFusedAttentionOp(OpTest):
def
generate_input_data
(
self
):
self
.
query
=
np
.
random
.
rand
(
self
.
batch_size
,
self
.
query_length
,
self
.
embed_dim
).
astype
(
self
.
x_type
)
if
self
.
has_attn_mask
:
self
.
attn_mask
=
np
.
ones
(
(
self
.
batch_size
,
self
.
num_heads
,
self
.
query_length
,
self
.
key_length
),
...
...
@@ -93,7 +95,10 @@ class TestFusedAttentionOp(OpTest):
elif
self
.
attn_mask_type
==
np
.
float64
:
self
.
attn_mask
=
(
np
.
tril
(
self
.
attn_mask
)
-
1.0
)
*
1e9
else
:
raise
ValueError
(
"'attn_mask_type' should be 'int64' or 'float64'."
)
raise
ValueError
(
"'attn_mask_type' should be 'int64' or 'float64'."
)
else
:
self
.
attn_mask
=
None
self
.
key
,
self
.
value
=
self
.
query
,
self
.
query
self
.
dout
=
np
.
random
.
random
((
self
.
batch_size
,
self
.
query_length
,
...
...
@@ -102,7 +107,10 @@ class TestFusedAttentionOp(OpTest):
def
GetBaselineOut
(
self
):
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
tensor_query
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
attn_mask
=
None
residual
=
tensor_query
ln1_out
=
tensor_query
...
...
@@ -187,7 +195,10 @@ class TestFusedAttentionOp(OpTest):
qkv_bias
=
qkv_bias
.
reshape
((
3
,
self
.
num_heads
,
self
.
head_dim
))
x
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
attn_mask
=
None
qkv_weight_tensor
=
paddle
.
to_tensor
(
qkv_weight
,
stop_gradient
=
False
)
qkv_bias_tensor
=
paddle
.
to_tensor
(
qkv_bias
,
stop_gradient
=
False
)
epsilon
=
1e-05
...
...
@@ -218,6 +229,37 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
class
TestFusedAttentionOpNoneAttnMask
(
TestFusedAttentionOp
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
False
self
.
training
=
True
self
.
batch_size
=
8
...
...
@@ -247,6 +289,7 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
self
.
x_type
=
np
.
float16
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
浏览文件 @
472dcca4
...
...
@@ -152,6 +152,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
need_weight
=
False
...
...
@@ -172,6 +173,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
def
generate_input_data
(
self
):
self
.
query
=
np
.
random
.
rand
(
self
.
batch_size
,
self
.
query_length
,
self
.
embed_dim
).
astype
(
self
.
x_type
)
if
self
.
has_attn_mask
:
self
.
attn_mask
=
np
.
ones
(
(
self
.
batch_size
,
self
.
num_heads
,
self
.
query_length
,
self
.
key_length
),
...
...
@@ -181,10 +183,17 @@ class TestFusedAttentionAPI(unittest.TestCase):
elif
self
.
attn_mask_type
==
np
.
float64
:
self
.
attn_mask
=
(
np
.
tril
(
self
.
attn_mask
)
-
1.0
)
*
1e9
else
:
raise
ValueError
(
"'attn_mask_type' should be 'int64' or 'float64'."
)
raise
ValueError
(
"'attn_mask_type' should be 'int64' or 'float64'."
)
else
:
self
.
attn_mask
=
None
self
.
key
,
self
.
value
=
self
.
query
,
self
.
query
def
run_imperative
(
self
):
if
self
.
has_attn_mask
:
attn_mask_tensor
=
paddle
.
to_tensor
(
self
.
attn_mask
)
else
:
attn_mask_tensor
=
None
fused_attn
=
FusedMultiHeadAttention
(
self
.
embed_dim
,
self
.
num_heads
,
self
.
dropout_prob
,
self
.
attn_dropout_prob
,
self
.
kdim
,
self
.
vdim
,
self
.
pre_layer_norm
,
...
...
@@ -192,7 +201,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
out
=
fused_attn
(
paddle
.
to_tensor
(
self
.
query
),
paddle
.
to_tensor
(
self
.
query
),
paddle
.
to_tensor
(
self
.
query
),
paddle
.
to_tensor
(
self
.
attn_mask
)
)
paddle
.
to_tensor
(
self
.
query
),
attn_mask_tensor
)
ref_out
=
compute_reference
(
self
.
pre_layer_norm
,
self
.
query
,
self
.
attn_mask
,
fused_attn
.
pre_ln_scale
.
numpy
(),
...
...
@@ -203,7 +212,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
fused_attn
.
qkv_bias
.
numpy
(),
fused_attn
.
linear_weight
.
numpy
(),
fused_attn
.
linear_bias
.
numpy
())
self
.
assertTrue
(
np
.
allclose
(
ref_out
,
out
,
rtol
=
1e-5
,
atol
=
1e-5
)
)
np
.
testing
.
assert_allclose
(
ref_out
,
out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-5
)
def
run_static
(
self
):
fused_attn
=
FusedMultiHeadAttention
(
...
...
@@ -215,6 +224,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
name
=
'X'
,
shape
=
[
self
.
batch_size
,
self
.
query_length
,
self
.
embed_dim
],
dtype
=
self
.
x_type
)
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
static
.
data
(
name
=
'SrcMask'
,
shape
=
[
...
...
@@ -223,10 +233,13 @@ class TestFusedAttentionAPI(unittest.TestCase):
],
dtype
=
self
.
attn_mask_type
)
final_out
=
fused_attn
(
x
,
x
,
x
,
attn_mask
)
else
:
final_out
=
fused_attn
(
x
,
x
,
x
)
place
=
paddle
.
CUDAPlace
(
0
)
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
if
self
.
has_attn_mask
:
out
,
qkv_weight
,
qkv_bias
,
out_linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
ln_2_scale
,
ln_2_bias
=
exe
.
run
(
paddle
.
static
.
default_main_program
(),
feed
=
{
"X"
:
self
.
query
,
...
...
@@ -237,7 +250,16 @@ class TestFusedAttentionAPI(unittest.TestCase):
fused_attn
.
pre_ln_scale
,
fused_attn
.
pre_ln_bias
,
fused_attn
.
ln_scale
,
fused_attn
.
ln_bias
])
else
:
out
,
qkv_weight
,
qkv_bias
,
out_linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
ln_2_scale
,
ln_2_bias
=
exe
.
run
(
paddle
.
static
.
default_main_program
(),
feed
=
{
"X"
:
self
.
query
,
},
fetch_list
=
[
final_out
,
fused_attn
.
qkv_weight
,
fused_attn
.
qkv_bias
,
fused_attn
.
linear_weight
,
fused_attn
.
linear_bias
,
fused_attn
.
pre_ln_scale
,
fused_attn
.
pre_ln_bias
,
fused_attn
.
ln_scale
,
fused_attn
.
ln_bias
])
return
out
,
qkv_weight
,
qkv_bias
,
out_linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
ln_2_scale
,
ln_2_bias
def
test_static_api
(
self
):
...
...
@@ -249,14 +271,36 @@ class TestFusedAttentionAPI(unittest.TestCase):
self
.
attn_mask
,
ln_scale
,
ln_bias
,
ln_2_scale
,
ln_2_bias
,
qkv_weight
,
qkv_bias
,
linear_weight
,
linear_bias
)
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
ref_out
),
np
.
array
(
out
),
rtol
=
1e-5
,
atol
=
1e-5
))
np
.
testing
.
assert_allclose
(
ref_out
,
out
,
rtol
=
1e-5
,
atol
=
1e-5
)
def
test_dynamic_api
(
self
):
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
self
.
run_imperative
()
class
TestFusedAttentionAPINoneAttnMask
(
TestFusedAttentionAPI
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
False
self
.
training
=
True
self
.
need_weight
=
False
self
.
batch_size
=
1
self
.
query_length
=
2
self
.
head_dim
=
2
self
.
num_heads
=
2
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录