Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ad44a40c
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ad44a40c
编写于
11月 10, 2021
作者:
L
Li Min
提交者:
GitHub
11月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix fused_attention_op scope. (#37065)
att, bug fix
上级
48d53cfc
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
135 addition
and
108 deletion
+135
-108
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+70
-58
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+52
-39
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
...n/paddle/fluid/tests/unittests/test_fused_attention_op.py
+8
-8
python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
...ddle/fluid/tests/unittests/test_fused_attention_op_api.py
+5
-3
未找到文件。
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
ad44a40c
...
@@ -42,6 +42,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -42,6 +42,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnOut"
),
"Output"
,
"LnOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnOut"
),
"Output"
,
"LnOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
}
else
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Ln2Mean"
),
"Output"
,
"Ln2Mean"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Ln2Variance"
),
"Output"
,
"Ln2Variance"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BiasDropoutResidualOut"
),
"Output"
,
"BiasDropoutResidualOut"
,
"FusedAttentionOp"
);
}
}
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
...
@@ -70,12 +77,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -70,12 +77,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutLinearOut"
),
"Output"
,
"OutLinearOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutLinearOut"
),
"Output"
,
"OutLinearOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Ln2Mean"
),
"Output"
,
"Ln2Mean"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Ln2Variance"
),
"Output"
,
"Ln2Variance"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BiasDropoutResidualOut"
),
"Output"
,
"BiasDropoutResidualOut"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"DropoutMaskOut"
),
"Output"
,
"DropoutMaskOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"DropoutMaskOut"
),
"Output"
,
"DropoutMaskOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Y"
),
"Output"
,
"Y"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Y"
),
"Output"
,
"Y"
,
"FusedAttentionOp"
);
...
@@ -109,6 +111,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -109,6 +111,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnVariance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnVariance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnOut"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"LnOut"
,
ctx
->
GetInputDim
(
"X"
));
}
else
{
ctx
->
SetOutputDim
(
"Ln2Mean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"Ln2Variance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"BiasDropoutResidualOut"
,
ctx
->
GetInputDim
(
"X"
));
}
}
// [batch_size, seq_len, 3, num_head, head_size]
// [batch_size, seq_len, 3, num_head, head_size]
ctx
->
SetOutputDim
(
"QKVOut"
,
ctx
->
SetOutputDim
(
"QKVOut"
,
...
@@ -138,12 +144,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -138,12 +144,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"FMHAOut"
,
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
1
],
y_dim
[
2
]});
ctx
->
SetOutputDim
(
"FMHAOut"
,
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
1
],
y_dim
[
2
]});
ctx
->
SetOutputDim
(
"OutLinearOut"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutLinearOut"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Ln2Mean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"Ln2Variance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"dropout_is_test"
)
==
false
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"dropout_is_test"
)
==
false
)
{
ctx
->
SetOutputDim
(
"DropoutMaskOut"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"DropoutMaskOut"
,
ctx
->
GetInputDim
(
"X"
));
}
}
ctx
->
SetOutputDim
(
"BiasDropoutResidualOut"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
}
}
...
@@ -314,25 +318,28 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -314,25 +318,28 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
});
});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Add fused attention op whose logic is as follows:
Add fused attention op whose logic is as follows:
// @input: [batch_size, seq_len, 3, num_head, head_dim]
// @input: [batch_size, seq_len, 3, num_head, head_dim]
// @final_out: [batch_size, seq_len, num_heads, head_dim]
// @final_out: [batch_size, seq_len, num_heads, head_dim]
if (pre_layernorm)
if (pre_layernorm)
out = layer_norm(input);
out = layer_norm(input);
out = compute_qkv(out) + bias;
out = compute_qkv(out) + bias;
// fmha module
// fmha module
{
{
out = transpose(out, perm=[2, 0, 3, 1, 4]);
out = transpose(out, perm=[2, 0, 3, 1, 4]);
out = q * k^t;
out = q * k^t;
out = attn_mask + out;
out = attn_mask + out;
out = softmax(out);
out = softmax(out);
out = dropout(out);
out = dropout(out);
out = out * v;
out = out * v;
out = transpose(out, perm=[0, 2, 1, 3]);
out = transpose(out, perm=[0, 2, 1, 3]);
}
}
out = out_linear(out);
out = out_linear(out);
final_out = layer_norm(residual + dropout(bias + out));
if (pre_layernorm)
final_out = residual + dropout(bias + out);
else
final_out = layer_norm(residual + dropout(bias + out));
)DOC"
);
)DOC"
);
}
}
};
};
...
@@ -347,20 +354,20 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -347,20 +354,20 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"GradOp is only callable when attn_dropout_is_test is false"
));
"GradOp is only callable when attn_dropout_is_test is false"
));
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Mean"
),
"Input"
,
"Ln2Mean"
,
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
false
)
{
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Mean"
),
"Input"
,
"Ln2Mean"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Variance"
),
"Input"
,
"Ln2Variance"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Variance"
),
"Input"
,
"Ln2Variance"
,
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Ln2Scale"
)))
{
"FusedAttentionGrad"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ln2Scale"
),
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Ln2Scale"
)))
{
ctx
->
GetInputDim
(
"Ln2Scale"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ln2Scale"
),
}
ctx
->
GetInputDim
(
"Ln2Scale"
));
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Ln2Bias"
)))
{
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ln2Bias"
),
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Ln2Bias"
)))
{
ctx
->
GetInputDim
(
"Ln2Bias"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ln2Bias"
),
}
ctx
->
GetInputDim
(
"Ln2Bias"
));
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionGrad"
);
}
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
}
else
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnMean"
),
"Input"
,
"LnMean"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnMean"
),
"Input"
,
"LnMean"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnVariance"
),
"Input"
,
"LnVariance"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnVariance"
),
"Input"
,
"LnVariance"
,
...
@@ -368,6 +375,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -368,6 +375,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnOut"
),
"Input"
,
"LnOut"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnOut"
),
"Input"
,
"LnOut"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
}
}
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVW"
),
"Input"
,
"QKVW"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVW"
),
"Input"
,
"QKVW"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVBias"
),
"Input"
,
"QKVBias"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVBias"
),
"Input"
,
"QKVBias"
,
...
@@ -402,6 +411,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -402,6 +411,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnOut"
),
ctx
->
GetInputDim
(
"LnOut"
));
ctx
->
GetInputDim
(
"LnOut"
));
}
else
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
),
ctx
->
GetInputDim
(
"BiasDropoutResidualOut"
));
}
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"FMHAOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"FMHAOut"
),
ctx
->
GetInputDim
(
"FMHAOut"
));
ctx
->
GetInputDim
(
"FMHAOut"
));
...
@@ -426,8 +438,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -426,8 +438,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"QKVBiasOut"
));
ctx
->
GetInputDim
(
"QKVBiasOut"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"OutLinearOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"OutLinearOut"
),
ctx
->
GetInputDim
(
"OutLinearOut"
));
ctx
->
GetInputDim
(
"OutLinearOut"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
),
ctx
->
GetInputDim
(
"BiasDropoutResidualOut"
));
}
}
protected:
protected:
...
@@ -478,17 +488,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -478,17 +488,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetOutput
(
framework
::
GradVarName
(
"LnBias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"LnBias"
),
this
->
InputGrad
(
"LnBias"
));
this
->
InputGrad
(
"LnBias"
));
}
}
}
}
else
{
if
(
this
->
HasInput
(
"Ln2Scale"
))
{
if
(
this
->
HasInput
(
"Ln2Scale"
))
{
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
this
->
InputGrad
(
"Ln2Scale"
));
this
->
InputGrad
(
"Ln2Scale"
));
}
}
if
(
this
->
HasInput
(
"Ln2Bias"
))
{
if
(
this
->
HasInput
(
"Ln2Bias"
))
{
op
->
SetInput
(
"Ln2Bias"
,
this
->
Input
(
"Ln2Bias"
));
op
->
SetInput
(
"Ln2Bias"
,
this
->
Input
(
"Ln2Bias"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Bias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Bias"
),
this
->
InputGrad
(
"Ln2Bias"
));
this
->
InputGrad
(
"Ln2Bias"
));
}
}
}
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
...
@@ -511,6 +521,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -511,6 +521,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
if
(
this
->
HasOutput
(
"LnVariance"
))
{
if
(
this
->
HasOutput
(
"LnVariance"
))
{
op
->
SetInput
(
"LnVariance"
,
this
->
Output
(
"LnVariance"
));
op
->
SetInput
(
"LnVariance"
,
this
->
Output
(
"LnVariance"
));
}
}
}
else
{
op
->
SetInput
(
"Ln2Mean"
,
this
->
Output
(
"Ln2Mean"
));
op
->
SetInput
(
"Ln2Variance"
,
this
->
Output
(
"Ln2Variance"
));
op
->
SetInput
(
"BiasDropoutResidualOut"
,
this
->
Output
(
"BiasDropoutResidualOut"
));
}
}
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVBiasOut"
,
this
->
Output
(
"QKVBiasOut"
));
op
->
SetInput
(
"QKVBiasOut"
,
this
->
Output
(
"QKVBiasOut"
));
...
@@ -523,12 +538,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -523,12 +538,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"FMHAOut"
,
this
->
Output
(
"FMHAOut"
));
op
->
SetInput
(
"FMHAOut"
,
this
->
Output
(
"FMHAOut"
));
op
->
SetInput
(
"OutLinearOut"
,
this
->
Output
(
"OutLinearOut"
));
op
->
SetInput
(
"OutLinearOut"
,
this
->
Output
(
"OutLinearOut"
));
op
->
SetInput
(
"Ln2Mean"
,
this
->
Output
(
"Ln2Mean"
));
op
->
SetInput
(
"Ln2Variance"
,
this
->
Output
(
"Ln2Variance"
));
op
->
SetInput
(
"DropoutMaskOut"
,
this
->
Output
(
"DropoutMaskOut"
));
op
->
SetInput
(
"DropoutMaskOut"
,
this
->
Output
(
"DropoutMaskOut"
));
op
->
SetInput
(
"BiasDropoutResidualOut"
,
this
->
Output
(
"BiasDropoutResidualOut"
));
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
// backward outputs: dinput
// backward outputs: dinput
...
@@ -537,7 +547,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -537,7 +547,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetOutput
(
framework
::
GradVarName
(
"LnOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"LnOut"
),
this
->
OutputGrad
(
"LnOut"
));
this
->
OutputGrad
(
"LnOut"
));
}
}
}
else
{
op
->
SetOutput
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
),
this
->
OutputGrad
(
"BiasDropoutResidualOut"
));
}
}
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVOut"
),
this
->
OutputGrad
(
"QKVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVOut"
),
this
->
OutputGrad
(
"QKVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
),
this
->
OutputGrad
(
"QKVBiasOut"
));
this
->
OutputGrad
(
"QKVBiasOut"
));
...
@@ -553,8 +567,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -553,8 +567,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetOutput
(
framework
::
GradVarName
(
"FMHAOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"FMHAOut"
),
this
->
OutputGrad
(
"FMHAOut"
));
this
->
OutputGrad
(
"FMHAOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
),
this
->
OutputGrad
(
"BiasDropoutResidualOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearOut"
),
this
->
OutputGrad
(
"OutLinearOut"
));
this
->
OutputGrad
(
"OutLinearOut"
));
}
}
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
ad44a40c
...
@@ -95,15 +95,6 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -95,15 +95,6 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
const
auto
qkv_w_dims
=
qkv_weight
->
dims
();
const
auto
qkv_w_dims
=
qkv_weight
->
dims
();
auto
*
x_data
=
input_x
->
data
<
T
>
();
auto
*
x_data
=
input_x
->
data
<
T
>
();
auto
*
ln_scale_data
=
(
ln_scale
==
nullptr
?
nullptr
:
ln_scale
->
data
<
U
>
());
auto
*
ln_bias_data
=
(
ln_bias
==
nullptr
?
nullptr
:
ln_bias
->
data
<
U
>
());
auto
*
ln_mean_data
=
pre_layer_norm
?
ln_mean
->
mutable_data
<
U
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
ln_var_data
=
pre_layer_norm
?
ln_var
->
mutable_data
<
U
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
ln_out_data
=
pre_layer_norm
?
ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_bias_data
=
qkv_bias
->
data
<
T
>
();
auto
*
qkv_bias_data
=
qkv_bias
->
data
<
T
>
();
auto
*
qkv_out_data
=
qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qkv_out_data
=
qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -130,16 +121,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -130,16 +121,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
*
out_linear_out_data
=
out_linear_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
out_linear_out_data
=
out_linear_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// get data ptr for bias+dropout+residual+layernorm
// get data ptr for bias+dropout+residual+layernorm
auto
*
ln_scale_2_data
=
(
ln_scale_2
==
nullptr
?
nullptr
:
ln_scale_2
->
data
<
U
>
());
auto
*
ln_bias_2_data
=
(
ln_bias_2
==
nullptr
?
nullptr
:
ln_bias_2
->
data
<
U
>
());
auto
*
dropout_mask_out_data
=
auto
*
dropout_mask_out_data
=
dropout_mask_out
->
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
dropout_mask_out
->
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
auto
*
bias_dropout_residual_out_data
=
bias_dropout_residual_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
ln_mean_2_data
=
ln_mean_2
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
ln_var_2_data
=
ln_var_2
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
final_out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
final_out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
batch_size
=
input_x_dims
[
0
];
int
batch_size
=
input_x_dims
[
0
];
...
@@ -178,6 +161,13 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -178,6 +161,13 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
ln_epsilon
);
ln_epsilon
);
if
(
pre_layer_norm
)
{
if
(
pre_layer_norm
)
{
auto
*
ln_scale_data
=
(
ln_scale
==
nullptr
?
nullptr
:
ln_scale
->
data
<
U
>
());
auto
*
ln_bias_data
=
(
ln_bias
==
nullptr
?
nullptr
:
ln_bias
->
data
<
U
>
());
auto
*
ln_mean_data
=
ln_mean
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
ln_var_data
=
ln_var
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
ln_out_data
=
ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
layer_norm_compute
.
ComputeForward
(
x_data
,
ln_scale_data
,
ln_bias_data
,
layer_norm_compute
.
ComputeForward
(
x_data
,
ln_scale_data
,
ln_bias_data
,
ln_out_data
,
ln_mean_data
,
ln_var_data
);
ln_out_data
,
ln_mean_data
,
ln_var_data
);
qkv_compute
.
ComputeForward
(
qkv_weight_data
,
ln_out_data
,
qkv_bias_data
,
qkv_compute
.
ComputeForward
(
qkv_weight_data
,
ln_out_data
,
qkv_bias_data
,
...
@@ -196,12 +186,27 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -196,12 +186,27 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// out_linear_out: [batch_size, seq_len, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute
.
ComputeForward
(
out_linear_weight_data
,
fmha_out_data
,
out_linear_compute
.
ComputeForward
(
out_linear_weight_data
,
fmha_out_data
,
nullptr
,
out_linear_out_data
,
nullptr
);
nullptr
,
out_linear_out_data
,
nullptr
);
// output = layernorm(residual + dropout(input + bias))
if
(
pre_layer_norm
)
{
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBias
(
// output = (residual + dropout(input + bias))
ctx
.
cuda_device_context
(),
out_linear_out_data
,
x_data
,
fused_dropout_layernorm_helper
.
ResidualDropoutBias
(
out_linear_bias_data
,
ln_scale_2_data
,
ln_bias_2_data
,
ctx
.
cuda_device_context
(),
out_linear_out_data
,
x_data
,
bias_dropout_residual_out_data
,
dropout_mask_out_data
,
final_out_data
,
out_linear_bias_data
,
final_out_data
,
dropout_mask_out_data
);
ln_mean_2_data
,
ln_var_2_data
);
}
else
{
auto
*
ln_scale_2_data
=
(
ln_scale_2
==
nullptr
?
nullptr
:
ln_scale_2
->
data
<
U
>
());
auto
*
ln_bias_2_data
=
(
ln_bias_2
==
nullptr
?
nullptr
:
ln_bias_2
->
data
<
U
>
());
auto
*
bias_dropout_residual_out_data
=
bias_dropout_residual_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
ln_mean_2_data
=
ln_mean_2
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
ln_var_2_data
=
ln_var_2
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBias
(
ctx
.
cuda_device_context
(),
out_linear_out_data
,
x_data
,
out_linear_bias_data
,
ln_scale_2_data
,
ln_bias_2_data
,
bias_dropout_residual_out_data
,
dropout_mask_out_data
,
final_out_data
,
ln_mean_2_data
,
ln_var_2_data
);
}
}
}
};
};
...
@@ -271,10 +276,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -271,10 +276,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
src_mask_out_data
=
auto
*
src_mask_out_data
=
(
src_mask
==
nullptr
)
?
nullptr
:
src_mask_out
->
data
<
T
>
();
(
src_mask
==
nullptr
)
?
nullptr
:
src_mask_out
->
data
<
T
>
();
auto
*
out_linear_out_data
=
out_linear_out
->
data
<
T
>
();
auto
*
out_linear_out_data
=
out_linear_out
->
data
<
T
>
();
auto
*
ln_2_mean_data
=
ln_2_mean
->
data
<
U
>
();
auto
*
ln_2_var_data
=
ln_2_var
->
data
<
U
>
();
auto
*
dropout_mask_out_data
=
dropout_mask_out
->
data
<
uint8_t
>
();
auto
*
dropout_mask_out_data
=
dropout_mask_out
->
data
<
uint8_t
>
();
auto
*
bias_dropout_residual_out_data
=
bias_dropout_residual_out
->
data
<
T
>
();
// output's grad
// output's grad
auto
*
d_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
...
@@ -312,8 +314,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -312,8 +314,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
d_fmha_out_data
=
d_fmha_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_fmha_out_data
=
d_fmha_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_out_linear_out_data
=
auto
*
d_out_linear_out_data
=
d_out_linear_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_out_linear_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_bias_dropout_residual_out_data
=
d_bias_dropout_residual_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// parameter grad
// parameter grad
auto
*
d_qkv_weight
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVW"
));
auto
*
d_qkv_weight
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVW"
));
...
@@ -331,12 +331,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -331,12 +331,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_out_linear_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_out_linear_bias_data
=
auto
*
d_out_linear_bias_data
=
d_out_linear_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_out_linear_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_ln_2_scale_data
=
(
d_ln_2_scale
==
nullptr
?
nullptr
:
d_ln_2_scale
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_ln_2_bias_data
=
(
d_ln_2_bias
==
nullptr
?
nullptr
:
d_ln_2_bias
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
const
auto
input_x_dims
=
input_x
->
dims
();
const
auto
input_x_dims
=
input_x
->
dims
();
const
auto
qkv_w_dims
=
qkv_weight
->
dims
();
const
auto
qkv_w_dims
=
qkv_weight
->
dims
();
...
@@ -382,11 +376,30 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -382,11 +376,30 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ctx
.
cuda_device_context
(),
bsz_seq
,
dim_embed
,
dropout_param2
,
ctx
.
cuda_device_context
(),
bsz_seq
,
dim_embed
,
dropout_param2
,
ln2epsilon
);
ln2epsilon
);
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBiasGrad
(
if
(
pre_layer_norm
)
{
ctx
.
cuda_device_context
(),
d_y_data
,
bias_dropout_residual_out_data
,
fused_dropout_layernorm_helper
.
ResidualDropoutBiasGrad
(
dropout_mask_out_data
,
ln_2_scale_data
,
ln_2_mean_data
,
ln_2_var_data
,
ctx
.
cuda_device_context
(),
d_y_data
,
dropout_mask_out_data
,
d_bias_dropout_residual_out_data
,
d_ln_2_scale_data
,
d_ln_2_bias_data
,
d_out_linear_out_data
,
d_residual_data
,
d_out_linear_bias_data
);
d_out_linear_out_data
,
d_out_linear_bias_data
,
d_residual_data
);
}
else
{
auto
*
ln_2_mean_data
=
ln_2_mean
->
data
<
U
>
();
auto
*
ln_2_var_data
=
ln_2_var
->
data
<
U
>
();
auto
*
bias_dropout_residual_out_data
=
bias_dropout_residual_out
->
data
<
T
>
();
auto
*
d_ln_2_scale_data
=
(
d_ln_2_scale
==
nullptr
?
nullptr
:
d_ln_2_scale
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_ln_2_bias_data
=
(
d_ln_2_bias
==
nullptr
?
nullptr
:
d_ln_2_bias
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_bias_dropout_residual_out_data
=
d_bias_dropout_residual_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBiasGrad
(
ctx
.
cuda_device_context
(),
d_y_data
,
bias_dropout_residual_out_data
,
dropout_mask_out_data
,
ln_2_scale_data
,
ln_2_mean_data
,
ln_2_var_data
,
d_bias_dropout_residual_out_data
,
d_ln_2_scale_data
,
d_ln_2_bias_data
,
d_out_linear_out_data
,
d_out_linear_bias_data
,
d_residual_data
);
}
out_linear_compute
.
ComputeBackward
(
fmha_out_data
,
out_linear_weight_data
,
out_linear_compute
.
ComputeBackward
(
fmha_out_data
,
out_linear_weight_data
,
d_out_linear_out_data
,
d_fmha_out_data
,
d_out_linear_out_data
,
d_fmha_out_data
,
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
浏览文件 @
ad44a40c
...
@@ -155,8 +155,8 @@ class TestFusedAttentionOp(OpTest):
...
@@ -155,8 +155,8 @@ class TestFusedAttentionOp(OpTest):
residual_out
=
residual
+
self
.
dropout
(
out
)
residual_out
=
residual
+
self
.
dropout
(
out
)
if
not
self
.
pre_layer_norm
:
if
not
self
.
pre_layer_norm
:
final_out
=
self
.
norm1
(
residual_out
)
final_out
=
self
.
norm1
(
residual_out
)
if
self
.
pre_layer_norm
:
else
:
final_out
=
self
.
norm2
(
residual_out
)
final_out
=
residual_out
paddle
.
autograd
.
backward
(
paddle
.
autograd
.
backward
(
[
final_out
],
[
paddle
.
to_tensor
(
self
.
dout
)],
retain_graph
=
True
)
[
final_out
],
[
paddle
.
to_tensor
(
self
.
dout
)],
retain_graph
=
True
)
return
final_out
,
tensor_query
.
grad
return
final_out
,
tensor_query
.
grad
...
@@ -219,9 +219,9 @@ class TestFusedAttentionOp(OpTest):
...
@@ -219,9 +219,9 @@ class TestFusedAttentionOp(OpTest):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
5
)
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
4
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
5
)
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
4
)
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
...
@@ -249,9 +249,9 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
...
@@ -249,9 +249,9 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
1
)
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
4
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
1
)
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
4
)
class
TestFusedAttentionOpNoneAttnMask
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpNoneAttnMask
(
TestFusedAttentionOp
):
...
@@ -279,9 +279,9 @@ class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
...
@@ -279,9 +279,9 @@ class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
1
)
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
4
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
1
)
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-
4
)
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
浏览文件 @
ad44a40c
...
@@ -138,9 +138,11 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
...
@@ -138,9 +138,11 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
out_linear_bias_out
=
out_linear_out
+
out_linear_bias
out_linear_bias_out
=
out_linear_out
+
out_linear_bias
out_linear_bias_dropout_out
=
out_linear_bias_out
out_linear_bias_dropout_out
=
out_linear_bias_out
out_linear_bias_dropout_residual_out
=
query
+
out_linear_bias_dropout_out
out_linear_bias_dropout_residual_out
=
query
+
out_linear_bias_dropout_out
out_linear_bias_dropout_residual_ln_out
=
layer_norm
(
if
not
pre_layer_norm
:
out_linear_bias_dropout_residual_out
,
True
,
True
,
ln_2_scale
,
ln_2_bias
)
out_linear_bias_dropout_residual_out
=
layer_norm
(
return
out_linear_bias_dropout_residual_ln_out
out_linear_bias_dropout_residual_out
,
True
,
True
,
ln_2_scale
,
ln_2_bias
)
return
out_linear_bias_dropout_residual_out
class
TestFusedAttentionAPI
(
unittest
.
TestCase
):
class
TestFusedAttentionAPI
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录