Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ff3018d7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ff3018d7
编写于
10月 28, 2021
作者:
L
Li Min
提交者:
GitHub
10月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix fused_attention_op and fused_feedforward_op bug when pre_layer_norm is false. (#36793)
* Fix bug when pre_layer_norm is false.
上级
9516108a
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
254 addition
and
144 deletion
+254
-144
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+64
-37
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+25
-19
paddle/fluid/operators/fused/fused_feedforward_op.cc
paddle/fluid/operators/fused/fused_feedforward_op.cc
+56
-44
paddle/fluid/operators/fused/fused_feedforward_op.cu
paddle/fluid/operators/fused/fused_feedforward_op.cu
+78
-42
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
...n/paddle/fluid/tests/unittests/test_fused_attention_op.py
+31
-2
未找到文件。
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
ff3018d7
...
@@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnMean"
),
"Output"
,
"LnMean"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnMean"
),
"Output"
,
"LnMean"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnVariance"
),
"Output"
,
"LnVariance"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnVariance"
),
"Output"
,
"LnVariance"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnOut"
),
"Output"
,
"LnOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnOut"
),
"Output"
,
"LnOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
}
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKVOut"
),
"Output"
,
"QKVOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKVOut"
),
"Output"
,
"QKVOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
...
@@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]"
,
"input qkv_weight = [%s]"
,
x_dim
,
y_dim
));
x_dim
,
y_dim
));
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnVariance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnVariance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnOut"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"LnOut"
,
ctx
->
GetInputDim
(
"X"
));
}
// [batch_size, seq_len, 3, num_head, head_size]
// [batch_size, seq_len, 3, num_head, head_size]
ctx
->
SetOutputDim
(
"QKVOut"
,
ctx
->
SetOutputDim
(
"QKVOut"
,
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
0
],
y_dim
[
1
],
y_dim
[
2
]});
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
0
],
y_dim
[
1
],
y_dim
[
2
]});
...
@@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"Ln2Bias"
));
ctx
->
GetInputDim
(
"Ln2Bias"
));
}
}
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionGrad"
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnMean"
),
"Input"
,
"LnMean"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnMean"
),
"Input"
,
"LnMean"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnVariance"
),
"Input"
,
"LnVariance"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnVariance"
),
"Input"
,
"LnVariance"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnOut"
),
"Input"
,
"LnOut"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnOut"
),
"Input"
,
"LnOut"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
}
}
...
@@ -370,6 +375,7 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -370,6 +375,7 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"LnScale"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"LnScale"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnScale"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnScale"
),
ctx
->
GetInputDim
(
"LnScale"
));
ctx
->
GetInputDim
(
"LnScale"
));
...
@@ -378,6 +384,7 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -378,6 +384,7 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnBias"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnBias"
),
ctx
->
GetInputDim
(
"LnBias"
));
ctx
->
GetInputDim
(
"LnBias"
));
}
}
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
...
@@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBias"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBias"
),
ctx
->
GetInputDim
(
"QKVBias"
));
ctx
->
GetInputDim
(
"QKVBias"
));
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnOut"
),
ctx
->
GetInputDim
(
"LnOut"
));
ctx
->
GetInputDim
(
"LnOut"
));
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"FMHAOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"FMHAOut"
),
ctx
->
GetInputDim
(
"FMHAOut"
));
ctx
->
GetInputDim
(
"FMHAOut"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKTVOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKTVOut"
),
...
@@ -442,6 +451,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -442,6 +451,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"SrcMask"
,
this
->
Input
(
"SrcMask"
));
op
->
SetInput
(
"SrcMask"
,
this
->
Input
(
"SrcMask"
));
op
->
SetInput
(
"OutLinearW"
,
this
->
Input
(
"OutLinearW"
));
op
->
SetInput
(
"OutLinearW"
,
this
->
Input
(
"OutLinearW"
));
op
->
SetInput
(
"OutLinearBias"
,
this
->
Input
(
"OutLinearBias"
));
op
->
SetInput
(
"OutLinearBias"
,
this
->
Input
(
"OutLinearBias"
));
op
->
SetAttrMap
(
this
->
Attrs
());
bool
is_pre_layer_norm
=
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"pre_layer_norm"
));
if
(
is_pre_layer_norm
)
{
if
(
this
->
HasInput
(
"LnScale"
))
{
if
(
this
->
HasInput
(
"LnScale"
))
{
op
->
SetInput
(
"LnScale"
,
this
->
Input
(
"LnScale"
));
op
->
SetInput
(
"LnScale"
,
this
->
Input
(
"LnScale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"LnScale"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"LnScale"
),
...
@@ -452,6 +466,8 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -452,6 +466,8 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetOutput
(
framework
::
GradVarName
(
"LnBias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"LnBias"
),
this
->
InputGrad
(
"LnBias"
));
this
->
InputGrad
(
"LnBias"
));
}
}
}
if
(
this
->
HasInput
(
"Ln2Scale"
))
{
if
(
this
->
HasInput
(
"Ln2Scale"
))
{
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
...
@@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this
->
InputGrad
(
"OutLinearW"
));
this
->
InputGrad
(
"OutLinearW"
));
// use forward outputs as backward inputs.
// use forward outputs as backward inputs.
if
(
is_pre_layer_norm
)
{
if
(
this
->
HasOutput
(
"LnOut"
))
{
op
->
SetInput
(
"LnOut"
,
this
->
Output
(
"LnOut"
));
op
->
SetInput
(
"LnOut"
,
this
->
Output
(
"LnOut"
));
}
if
(
this
->
HasOutput
(
"LnMean"
))
{
op
->
SetInput
(
"LnMean"
,
this
->
Output
(
"LnMean"
));
op
->
SetInput
(
"LnMean"
,
this
->
Output
(
"LnMean"
));
}
if
(
this
->
HasOutput
(
"LnVariance"
))
{
op
->
SetInput
(
"LnVariance"
,
this
->
Output
(
"LnVariance"
));
op
->
SetInput
(
"LnVariance"
,
this
->
Output
(
"LnVariance"
));
}
}
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVBiasOut"
,
this
->
Output
(
"QKVBiasOut"
));
op
->
SetInput
(
"QKVBiasOut"
,
this
->
Output
(
"QKVBiasOut"
));
op
->
SetInput
(
"TransposeOut2"
,
this
->
Output
(
"TransposeOut2"
));
op
->
SetInput
(
"TransposeOut2"
,
this
->
Output
(
"TransposeOut2"
));
...
@@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
// backward outputs: dinput
// backward outputs: dinput
op
->
SetOutput
(
framework
::
GradVarName
(
"LnOut"
),
this
->
OutputGrad
(
"LnOut"
));
if
(
is_pre_layer_norm
)
{
if
(
this
->
HasOutput
(
"LnOut"
))
{
op
->
SetOutput
(
framework
::
GradVarName
(
"LnOut"
),
this
->
OutputGrad
(
"LnOut"
));
}
}
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVOut"
),
this
->
OutputGrad
(
"QKVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVOut"
),
this
->
OutputGrad
(
"QKVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
),
this
->
OutputGrad
(
"QKVBiasOut"
));
this
->
OutputGrad
(
"QKVBiasOut"
));
...
@@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this
->
OutputGrad
(
"BiasDropoutResidualOut"
));
this
->
OutputGrad
(
"BiasDropoutResidualOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearOut"
),
this
->
OutputGrad
(
"OutLinearOut"
));
this
->
OutputGrad
(
"OutLinearOut"
));
op
->
SetAttrMap
(
this
->
Attrs
());
}
}
};
};
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
ff3018d7
...
@@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
*
x_data
=
input_x
->
data
<
T
>
();
auto
*
x_data
=
input_x
->
data
<
T
>
();
auto
*
ln_scale_data
=
(
ln_scale
==
nullptr
?
nullptr
:
ln_scale
->
data
<
U
>
());
auto
*
ln_scale_data
=
(
ln_scale
==
nullptr
?
nullptr
:
ln_scale
->
data
<
U
>
());
auto
*
ln_bias_data
=
(
ln_bias
==
nullptr
?
nullptr
:
ln_bias
->
data
<
U
>
());
auto
*
ln_bias_data
=
(
ln_bias
==
nullptr
?
nullptr
:
ln_bias
->
data
<
U
>
());
auto
*
ln_mean_data
=
ln_mean
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
ln_mean_data
=
auto
*
ln_var_data
=
ln_var
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
pre_layer_norm
?
ln_mean
->
mutable_data
<
U
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
ln_out_data
=
ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
ln_var_data
=
pre_layer_norm
?
ln_var
->
mutable_data
<
U
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
ln_out_data
=
pre_layer_norm
?
ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_bias_data
=
qkv_bias
->
data
<
T
>
();
auto
*
qkv_bias_data
=
qkv_bias
->
data
<
T
>
();
...
@@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
out_linear_bias_data
=
out_linear_bias
->
data
<
T
>
();
auto
*
out_linear_bias_data
=
out_linear_bias
->
data
<
T
>
();
// fw output
// fw output
auto
*
ln_mean
=
ctx
.
Input
<
Tensor
>
(
"LnMean"
);
auto
*
ln_var
=
ctx
.
Input
<
Tensor
>
(
"LnVariance"
);
auto
*
ln_out
=
ctx
.
Input
<
Tensor
>
(
"LnOut"
);
auto
*
fmha_out
=
ctx
.
Input
<
Tensor
>
(
"FMHAOut"
);
auto
*
fmha_out
=
ctx
.
Input
<
Tensor
>
(
"FMHAOut"
);
auto
*
transpose_out_2
=
ctx
.
Input
<
Tensor
>
(
"TransposeOut2"
);
auto
*
transpose_out_2
=
ctx
.
Input
<
Tensor
>
(
"TransposeOut2"
);
auto
*
qk_out
=
ctx
.
Input
<
Tensor
>
(
"QKOut"
);
auto
*
qk_out
=
ctx
.
Input
<
Tensor
>
(
"QKOut"
);
...
@@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
dropout_mask_out
=
ctx
.
Input
<
Tensor
>
(
"DropoutMaskOut"
);
auto
*
dropout_mask_out
=
ctx
.
Input
<
Tensor
>
(
"DropoutMaskOut"
);
auto
*
bias_dropout_residual_out
=
auto
*
bias_dropout_residual_out
=
ctx
.
Input
<
Tensor
>
(
"BiasDropoutResidualOut"
);
ctx
.
Input
<
Tensor
>
(
"BiasDropoutResidualOut"
);
auto
*
ln_mean_data
=
ln_mean
->
data
<
U
>
();
auto
*
ln_var_data
=
ln_var
->
data
<
U
>
();
auto
*
ln_out_data
=
ln_out
->
data
<
T
>
();
auto
*
fmha_out_data
=
fmha_out
->
data
<
T
>
();
auto
*
fmha_out_data
=
fmha_out
->
data
<
T
>
();
auto
*
transpose_out_2_data
=
transpose_out_2
->
data
<
T
>
();
auto
*
transpose_out_2_data
=
transpose_out_2
->
data
<
T
>
();
auto
*
qk_out_data
=
qk_out
->
data
<
T
>
();
auto
*
qk_out_data
=
qk_out
->
data
<
T
>
();
...
@@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
// output's grad
// output's grad
auto
*
d_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_ln_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnOut"
));
auto
*
d_qkv_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVOut"
));
auto
*
d_qkv_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVOut"
));
auto
*
d_qkv_bias_out
=
auto
*
d_qkv_bias_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVBiasOut"
));
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVBiasOut"
));
...
@@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
d_bias_dropout_residual_out
=
auto
*
d_bias_dropout_residual_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
));
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
));
auto
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_ln_out_data
=
d_ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_out_data
=
d_qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_out_data
=
d_qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_out_data
=
d_qkv_bias_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_out_data
=
d_qkv_bias_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qktv_out_data
=
d_qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qktv_out_data
=
d_qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_bias_dropout_residual_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_bias_dropout_residual_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// parameter grad
// parameter grad
auto
*
d_ln_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnScale"
));
auto
*
d_ln_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnBias"
));
auto
*
d_qkv_weight
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVW"
));
auto
*
d_qkv_weight
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVW"
));
auto
*
d_qkv_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVBias"
));
auto
*
d_qkv_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVBias"
));
auto
*
d_out_linear_weight
=
auto
*
d_out_linear_weight
=
...
@@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"OutLinearBias"
));
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"OutLinearBias"
));
auto
*
d_ln_2_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Scale"
));
auto
*
d_ln_2_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Scale"
));
auto
*
d_ln_2_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
auto
*
d_ln_2_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
auto
*
d_ln_scale_data
=
(
d_ln_scale
==
nullptr
?
nullptr
:
d_ln_scale
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_ln_bias_data
=
(
d_ln_bias
==
nullptr
?
nullptr
:
d_ln_bias
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_qkv_weight_data
=
d_qkv_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_weight_data
=
d_qkv_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_data
=
d_qkv_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_data
=
d_qkv_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_out_linear_weight_data
=
auto
*
d_out_linear_weight_data
=
...
@@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
cudaMemcpyDeviceToDevice
);
cudaMemcpyDeviceToDevice
);
if
(
pre_layer_norm
)
{
if
(
pre_layer_norm
)
{
auto
*
ln_mean
=
ctx
.
Input
<
Tensor
>
(
"LnMean"
);
auto
*
ln_var
=
ctx
.
Input
<
Tensor
>
(
"LnVariance"
);
auto
*
ln_out
=
ctx
.
Input
<
Tensor
>
(
"LnOut"
);
auto
*
ln_mean_data
=
ln_mean
->
data
<
U
>
();
auto
*
ln_var_data
=
ln_var
->
data
<
U
>
();
auto
*
ln_out_data
=
ln_out
->
data
<
T
>
();
auto
*
d_ln_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnOut"
));
auto
*
d_ln_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnScale"
));
auto
*
d_ln_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnBias"
));
auto
*
d_ln_out_data
=
d_ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_ln_scale_data
=
(
d_ln_scale
==
nullptr
?
nullptr
:
d_ln_scale
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_ln_bias_data
=
(
d_ln_bias
==
nullptr
?
nullptr
:
d_ln_bias
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
qkv_compute
.
ComputeBackward
(
ln_out_data
,
qkv_weight_data
,
qkv_compute
.
ComputeBackward
(
ln_out_data
,
qkv_weight_data
,
d_qkv_bias_out_data
,
d_ln_out_data
,
d_qkv_bias_out_data
,
d_ln_out_data
,
d_qkv_weight_data
,
d_qkv_bias_data
);
d_qkv_weight_data
,
d_qkv_bias_data
);
...
...
paddle/fluid/operators/fused/fused_feedforward_op.cc
浏览文件 @
ff3018d7
...
@@ -41,18 +41,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
...
@@ -41,18 +41,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
"fused_feedforward"
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout2Mask"
),
"Output"
,
"Dropout2Mask"
,
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout2Mask"
),
"Output"
,
"Dropout2Mask"
,
"fused_feedforward"
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Mean"
),
"Output"
,
"Ln1Mean"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Variance"
),
"Output"
,
"Ln1Variance"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln2Mean"
),
"Output"
,
"Ln2Mean"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln2Variance"
),
"Output"
,
"Ln2Variance"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Linear1Out"
),
"Output"
,
"Linear1Out"
,
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Linear1Out"
),
"Output"
,
"Linear1Out"
,
"fused_feedforward"
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Out"
),
"Output"
,
"Ln1Out"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout1Out"
),
"Output"
,
"Dropout1Out"
,
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout1Out"
),
"Output"
,
"Dropout1Out"
,
"fused_feedforward"
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout2Out"
),
"Output"
,
"Dropout2Out"
,
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout2Out"
),
"Output"
,
"Dropout2Out"
,
...
@@ -76,7 +66,6 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
...
@@ -76,7 +66,6 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
}
}
context
->
SetOutputDim
(
"Dropout1Out"
,
tmp_dim_x
);
context
->
SetOutputDim
(
"Dropout1Out"
,
tmp_dim_x
);
context
->
SetOutputDim
(
"Linear1Out"
,
tmp_dim_x
);
context
->
SetOutputDim
(
"Linear1Out"
,
tmp_dim_x
);
context
->
SetOutputDim
(
"Ln1Out"
,
dim_x
);
context
->
SetOutputDim
(
"Dropout2Out"
,
dim_x
);
context
->
SetOutputDim
(
"Dropout2Out"
,
dim_x
);
if
(
context
->
Attrs
().
Get
<
bool
>
(
"dropout2_is_test"
)
==
false
)
{
if
(
context
->
Attrs
().
Get
<
bool
>
(
"dropout2_is_test"
)
==
false
)
{
...
@@ -84,10 +73,25 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
...
@@ -84,10 +73,25 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
}
}
framework
::
DDim
mean_dim
=
framework
::
DDim
mean_dim
=
framework
::
make_ddim
({
mat_dim_x
.
batch_size_
*
mat_dim_x
.
height_
});
framework
::
make_ddim
({
mat_dim_x
.
batch_size_
*
mat_dim_x
.
height_
});
bool
pre_layer_norm
=
context
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
);
if
(
pre_layer_norm
)
{
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Mean"
),
"Output"
,
"Ln1Mean"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Variance"
),
"Output"
,
"Ln1Variance"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Out"
),
"Output"
,
"Ln1Out"
,
"fused_feedforward"
);
context
->
SetOutputDim
(
"Ln1Out"
,
dim_x
);
context
->
SetOutputDim
(
"Ln1Mean"
,
mean_dim
);
context
->
SetOutputDim
(
"Ln1Mean"
,
mean_dim
);
context
->
SetOutputDim
(
"Ln1Variance"
,
mean_dim
);
context
->
SetOutputDim
(
"Ln1Variance"
,
mean_dim
);
}
else
{
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln2Mean"
),
"Output"
,
"Ln2Mean"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln2Variance"
),
"Output"
,
"Ln2Variance"
,
"fused_feedforward"
);
context
->
SetOutputDim
(
"Ln2Mean"
,
mean_dim
);
context
->
SetOutputDim
(
"Ln2Mean"
,
mean_dim
);
context
->
SetOutputDim
(
"Ln2Variance"
,
mean_dim
);
context
->
SetOutputDim
(
"Ln2Variance"
,
mean_dim
);
}
context
->
ShareLoD
(
"X"
,
"Out"
);
context
->
ShareLoD
(
"X"
,
"Out"
);
}
}
...
@@ -218,14 +222,13 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
...
@@ -218,14 +222,13 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"dropout2_is_test"
),
false
,
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"dropout2_is_test"
),
false
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"GradOp is only callable when is_test is false"
));
"GradOp is only callable when is_test is false"
));
bool
pre_layer_norm
=
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout1Mask"
),
"Input"
,
"Dropout1Mask"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout1Mask"
),
"Input"
,
"Dropout1Mask"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout2Mask"
),
"Input"
,
"Dropout1Mask"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout2Mask"
),
"Input"
,
"Dropout1Mask"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Linear1Out"
),
"Input"
,
"Linear1Out"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Linear1Out"
),
"Input"
,
"Linear1Out"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Out"
),
"Input"
,
"Ln1Out"
,
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout1Out"
),
"Input"
,
"Dropout1Out"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout1Out"
),
"Input"
,
"Dropout1Out"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout2Out"
),
"Input"
,
"Dropout2Out"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout2Out"
),
"Input"
,
"Dropout2Out"
,
...
@@ -234,14 +237,19 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
...
@@ -234,14 +237,19 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Linear2Weight"
),
"Input"
,
"Linear2Weight"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Linear2Weight"
),
"Input"
,
"Linear2Weight"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
if
(
pre_layer_norm
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Mean"
),
"Input"
,
"Ln1Mean"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Mean"
),
"Input"
,
"Ln1Mean"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Variance"
),
"Input"
,
"Ln1Variance"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Variance"
),
"Input"
,
"Ln1Variance"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Out"
),
"Input"
,
"Ln1Out"
,
"FusedFeedForwardGrad"
);
}
else
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Mean"
),
"Input"
,
"Ln2Mean"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Mean"
),
"Input"
,
"Ln2Mean"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Variance"
),
"Input"
,
"Ln2Variance"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Variance"
),
"Input"
,
"Ln2Variance"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
}
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"FusedFeedForwardGrad"
);
framework
::
GradVarName
(
"Out"
),
"FusedFeedForwardGrad"
);
...
@@ -299,30 +307,36 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -299,30 +307,36 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"Linear1Weight"
,
this
->
Input
(
"Linear1Weight"
));
op
->
SetInput
(
"Linear1Weight"
,
this
->
Input
(
"Linear1Weight"
));
op
->
SetInput
(
"Linear1Bias"
,
this
->
Input
(
"Linear1Bias"
));
op
->
SetInput
(
"Linear1Bias"
,
this
->
Input
(
"Linear1Bias"
));
op
->
SetInput
(
"Linear2Weight"
,
this
->
Input
(
"Linear2Weight"
));
op
->
SetInput
(
"Linear2Weight"
,
this
->
Input
(
"Linear2Weight"
));
op
->
SetInput
(
"Ln1Scale"
,
this
->
Input
(
"Ln1Scale"
));
op
->
SetInput
(
"Ln1Bias"
,
this
->
Input
(
"Ln1Bias"
));
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetInput
(
"Ln2Bias"
,
this
->
Input
(
"Ln2Bias"
));
op
->
SetInput
(
"Dropout1Mask"
,
this
->
Output
(
"Dropout1Mask"
));
op
->
SetInput
(
"Dropout1Mask"
,
this
->
Output
(
"Dropout1Mask"
));
op
->
SetInput
(
"Dropout2Mask"
,
this
->
Output
(
"Dropout2Mask"
));
op
->
SetInput
(
"Dropout2Mask"
,
this
->
Output
(
"Dropout2Mask"
));
op
->
SetInput
(
"Linear1Out"
,
this
->
Output
(
"Linear1Out"
));
op
->
SetInput
(
"Linear1Out"
,
this
->
Output
(
"Linear1Out"
));
op
->
SetInput
(
"Ln1Out"
,
this
->
Output
(
"Ln1Out"
));
op
->
SetInput
(
"Ln1Mean"
,
this
->
Output
(
"Ln1Mean"
));
op
->
SetInput
(
"Ln1Variance"
,
this
->
Output
(
"Ln1Variance"
));
op
->
SetInput
(
"Ln2Mean"
,
this
->
Output
(
"Ln2Mean"
));
op
->
SetInput
(
"Ln2Variance"
,
this
->
Output
(
"Ln2Variance"
));
op
->
SetInput
(
"Dropout1Out"
,
this
->
Output
(
"Dropout1Out"
));
op
->
SetInput
(
"Dropout1Out"
,
this
->
Output
(
"Dropout1Out"
));
op
->
SetInput
(
"Dropout2Out"
,
this
->
Output
(
"Dropout2Out"
));
op
->
SetInput
(
"Dropout2Out"
,
this
->
Output
(
"Dropout2Out"
));
op
->
SetAttrMap
(
this
->
Attrs
());
bool
pre_layer_norm
=
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"pre_layer_norm"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
if
(
pre_layer_norm
)
{
op
->
SetInput
(
"Ln1Scale"
,
this
->
Input
(
"Ln1Scale"
));
op
->
SetInput
(
"Ln1Bias"
,
this
->
Input
(
"Ln1Bias"
));
op
->
SetInput
(
"Ln1Out"
,
this
->
Output
(
"Ln1Out"
));
op
->
SetInput
(
"Ln1Mean"
,
this
->
Output
(
"Ln1Mean"
));
op
->
SetInput
(
"Ln1Variance"
,
this
->
Output
(
"Ln1Variance"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln1Scale"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln1Scale"
),
this
->
InputGrad
(
"Ln1Scale"
));
this
->
InputGrad
(
"Ln1Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln1Bias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln1Bias"
),
this
->
InputGrad
(
"Ln1Bias"
));
this
->
InputGrad
(
"Ln1Bias"
));
}
else
{
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetInput
(
"Ln2Bias"
,
this
->
Input
(
"Ln2Bias"
));
op
->
SetInput
(
"Ln2Mean"
,
this
->
Output
(
"Ln2Mean"
));
op
->
SetInput
(
"Ln2Variance"
,
this
->
Output
(
"Ln2Variance"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
this
->
InputGrad
(
"Ln2Scale"
));
this
->
InputGrad
(
"Ln2Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Bias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Bias"
),
this
->
InputGrad
(
"Ln2Bias"
));
this
->
InputGrad
(
"Ln2Bias"
));
}
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear1Weight"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear1Weight"
),
this
->
InputGrad
(
"Linear1Weight"
));
this
->
InputGrad
(
"Linear1Weight"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear1Bias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear1Bias"
),
...
@@ -334,8 +348,6 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -334,8 +348,6 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear2Bias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear2Bias"
),
this
->
InputGrad
(
"Linear2Bias"
));
this
->
InputGrad
(
"Linear2Bias"
));
}
}
op
->
SetAttrMap
(
this
->
Attrs
());
}
}
};
};
...
...
paddle/fluid/operators/fused/fused_feedforward_op.cu
浏览文件 @
ff3018d7
...
@@ -113,26 +113,40 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
...
@@ -113,26 +113,40 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
auto
*
linear1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Bias"
);
auto
*
linear1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Bias"
);
auto
*
linear2_weight
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Weight"
);
auto
*
linear2_weight
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Weight"
);
auto
*
linear2_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Bias"
);
auto
*
linear2_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Bias"
);
auto
*
ln1_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Scale"
);
const
bool
pre_layer_norm
=
context
.
Attr
<
bool
>
(
"pre_layer_norm"
);
auto
*
ln1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Bias"
);
auto
*
ln2_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Scale"
);
auto
*
ln1_scale
=
auto
*
ln2_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Bias"
);
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Scale"
)
:
nullptr
;
auto
*
ln1_bias
=
auto
*
ln1_mean
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Mean"
);
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Bias"
)
:
nullptr
;
auto
*
ln1_variance
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Variance"
);
auto
*
ln2_scale
=
!
pre_layer_norm
auto
*
ln2_mean
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln2Mean"
);
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Scale"
)
auto
*
ln2_variance
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln2Variance"
);
:
nullptr
;
auto
*
ln2_bias
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Bias"
)
:
nullptr
;
auto
*
ln1_mean
=
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Mean"
)
:
nullptr
;
auto
*
ln1_variance
=
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Variance"
)
:
nullptr
;
auto
*
ln2_mean
=
!
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln2Mean"
)
:
nullptr
;
auto
*
ln2_variance
=
!
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln2Variance"
)
:
nullptr
;
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
dropout1_mask
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout1Mask"
);
auto
*
dropout1_mask
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout1Mask"
);
auto
*
dropout2_mask
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout2Mask"
);
auto
*
dropout2_mask
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout2Mask"
);
auto
*
linear1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Linear1Out"
);
auto
*
linear1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Linear1Out"
);
auto
*
ln1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Out"
);
auto
*
ln1_out
=
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Out"
)
:
nullptr
;
auto
*
dropout1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout1Out"
);
auto
*
dropout1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout1Out"
);
auto
*
dropout2_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout2Out"
);
auto
*
dropout2_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout2Out"
);
const
std
::
string
act_method
=
context
.
Attr
<
std
::
string
>
(
"act_method"
);
const
std
::
string
act_method
=
context
.
Attr
<
std
::
string
>
(
"act_method"
);
const
bool
pre_layer_norm
=
context
.
Attr
<
bool
>
(
"pre_layer_norm"
);
const
float
epsilon1
=
context
.
Attr
<
float
>
(
"ln1_epsilon"
);
const
float
epsilon1
=
context
.
Attr
<
float
>
(
"ln1_epsilon"
);
const
float
epsilon2
=
context
.
Attr
<
float
>
(
"ln2_epsilon"
);
const
float
epsilon2
=
context
.
Attr
<
float
>
(
"ln2_epsilon"
);
...
@@ -144,12 +158,16 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
...
@@ -144,12 +158,16 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
out
->
mutable_data
<
T
>
(
place
);
out
->
mutable_data
<
T
>
(
place
);
dropout1_mask
->
mutable_data
<
uint8_t
>
(
place
);
dropout1_mask
->
mutable_data
<
uint8_t
>
(
place
);
dropout2_mask
->
mutable_data
<
uint8_t
>
(
place
);
dropout2_mask
->
mutable_data
<
uint8_t
>
(
place
);
if
(
pre_layer_norm
)
{
ln1_mean
->
mutable_data
<
U
>
(
place
);
ln1_mean
->
mutable_data
<
U
>
(
place
);
ln1_variance
->
mutable_data
<
U
>
(
place
);
ln1_variance
->
mutable_data
<
U
>
(
place
);
ln1_out
->
mutable_data
<
T
>
(
place
);
}
else
{
ln2_mean
->
mutable_data
<
U
>
(
place
);
ln2_mean
->
mutable_data
<
U
>
(
place
);
ln2_variance
->
mutable_data
<
U
>
(
place
);
ln2_variance
->
mutable_data
<
U
>
(
place
);
}
linear1_out
->
mutable_data
<
T
>
(
place
);
linear1_out
->
mutable_data
<
T
>
(
place
);
ln1_out
->
mutable_data
<
T
>
(
place
);
dropout1_out
->
mutable_data
<
T
>
(
place
);
dropout1_out
->
mutable_data
<
T
>
(
place
);
dropout2_out
->
mutable_data
<
T
>
(
place
);
dropout2_out
->
mutable_data
<
T
>
(
place
);
...
@@ -193,16 +211,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -193,16 +211,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const
framework
::
Tensor
&
d_out
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
d_out
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
dropout1_mask
,
const
framework
::
Tensor
&
dropout1_mask
,
const
framework
::
Tensor
&
dropout2_mask
,
const
framework
::
Tensor
&
dropout2_mask
,
const
framework
::
Tensor
&
linear1_out
,
const
framework
::
Tensor
&
ln1_out
,
const
framework
::
Tensor
&
linear1_out
,
const
framework
::
Tensor
*
ln1_out
,
const
framework
::
Tensor
&
dropout1_out
,
const
framework
::
Tensor
&
dropout1_out
,
const
framework
::
Tensor
&
dropout2_out
,
const
framework
::
Tensor
&
dropout2_out
,
const
framework
::
Tensor
&
linear1_weight
,
const
framework
::
Tensor
&
linear1_weight
,
const
framework
::
Tensor
*
linear1_bias
,
const
framework
::
Tensor
*
linear1_bias
,
const
framework
::
Tensor
&
linear2_weight
,
const
framework
::
Tensor
&
linear2_weight
,
const
framework
::
Tensor
*
ln1_gamma
,
const
framework
::
Tensor
*
ln1_beta
,
const
framework
::
Tensor
*
ln1_gamma
,
const
framework
::
Tensor
*
ln1_beta
,
const
framework
::
Tensor
&
ln1_mean
,
const
framework
::
Tensor
&
ln1_variance
,
const
framework
::
Tensor
*
ln1_mean
,
const
framework
::
Tensor
*
ln1_variance
,
const
framework
::
Tensor
*
ln2_gamma
,
const
framework
::
Tensor
*
ln2_beta
,
const
framework
::
Tensor
*
ln2_gamma
,
const
framework
::
Tensor
*
ln2_beta
,
const
framework
::
Tensor
&
ln2_mean
,
const
framework
::
Tensor
&
ln2_variance
,
const
framework
::
Tensor
*
ln2_mean
,
const
framework
::
Tensor
*
ln2_variance
,
framework
::
Tensor
*
d_x
,
framework
::
Tensor
*
d_linear1_weight
,
framework
::
Tensor
*
d_x
,
framework
::
Tensor
*
d_linear1_weight
,
framework
::
Tensor
*
d_linear1_bias
,
framework
::
Tensor
*
d_linear2_weight
,
framework
::
Tensor
*
d_linear1_bias
,
framework
::
Tensor
*
d_linear2_weight
,
framework
::
Tensor
*
d_linear2_bias
,
framework
::
Tensor
*
d_ln1_gamma
,
framework
::
Tensor
*
d_linear2_bias
,
framework
::
Tensor
*
d_ln1_gamma
,
...
@@ -252,8 +270,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -252,8 +270,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
}
else
{
}
else
{
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBiasGrad
(
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBiasGrad
(
ctx
,
d_out
.
data
<
T
>
(),
dropout2_out
.
data
<
T
>
(),
ctx
,
d_out
.
data
<
T
>
(),
dropout2_out
.
data
<
T
>
(),
dropout2_mask
.
data
<
uint8_t
>
(),
ln2_gamma_ptr
,
ln2_mean
.
data
<
U
>
(),
dropout2_mask
.
data
<
uint8_t
>
(),
ln2_gamma_ptr
,
ln2_mean
->
data
<
U
>
(),
ln2_variance
.
data
<
U
>
(),
d_dropout2_out
.
data
<
T
>
(),
d_ln2_gamma_ptr
,
ln2_variance
->
data
<
U
>
(),
d_dropout2_out
.
data
<
T
>
(),
d_ln2_gamma_ptr
,
d_ln2_beta_ptr
,
d_linear2_out
.
data
<
T
>
(),
d_linear2_bias_ptr
,
d_ln2_beta_ptr
,
d_linear2_out
.
data
<
T
>
(),
d_linear2_bias_ptr
,
d_residual
.
data
<
T
>
());
d_residual
.
data
<
T
>
());
}
}
...
@@ -273,12 +291,12 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -273,12 +291,12 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
if
(
pre_layer_norm
)
{
if
(
pre_layer_norm
)
{
framework
::
Tensor
d_ln1_out
;
framework
::
Tensor
d_ln1_out
;
d_ln1_out
.
mutable_data
<
T
>
({
bsz_seq
,
d_model
},
place
);
d_ln1_out
.
mutable_data
<
T
>
({
bsz_seq
,
d_model
},
place
);
MatMulGrad
(
ctx
,
d_linear1_out
,
ln1_out
,
linear1_weight
,
&
d_ln1_out
,
MatMulGrad
(
ctx
,
d_linear1_out
,
*
ln1_out
,
linear1_weight
,
&
d_ln1_out
,
d_linear1_weight
);
d_linear1_weight
);
pre_layernorm_helper
.
LayerNormGrad
(
ctx
,
d_ln1_out
.
data
<
T
>
(),
x
.
data
<
T
>
(),
pre_layernorm_helper
.
LayerNormGrad
(
ln1_gamma_ptr
,
ln1_mean
.
data
<
U
>
()
,
ctx
,
d_ln1_out
.
data
<
T
>
(),
x
.
data
<
T
>
(),
ln1_gamma_ptr
,
ln1_variance
.
data
<
U
>
(),
d_x
->
data
<
T
>
(),
ln1_mean
->
data
<
U
>
(),
ln1_variance
->
data
<
U
>
(),
d_x
->
data
<
T
>
(),
d_ln1_gamma_ptr
,
d_ln1_beta_ptr
);
d_ln1_gamma_ptr
,
d_ln1_beta_ptr
);
}
else
{
}
else
{
MatMulGrad
(
ctx
,
d_linear1_out
,
x
,
linear1_weight
,
d_x
,
d_linear1_weight
);
MatMulGrad
(
ctx
,
d_linear1_out
,
x
,
linear1_weight
,
d_x
,
d_linear1_weight
);
...
@@ -290,33 +308,52 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -290,33 +308,52 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
auto
d_out
=
auto
d_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
*
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
x
=
*
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
x
=
*
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
bool
pre_layer_norm
=
context
.
Attr
<
bool
>
(
"pre_layer_norm"
);
auto
dropout1_mask
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout1Mask"
);
auto
dropout1_mask
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout1Mask"
);
auto
dropout2_mask
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout2Mask"
);
auto
dropout2_mask
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout2Mask"
);
auto
linear1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Out"
);
auto
linear1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Out"
);
auto
ln1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Out"
);
auto
*
ln1_out
=
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Out"
)
:
nullptr
;
auto
dropout1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout1Out"
);
auto
dropout1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout1Out"
);
auto
dropout2_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout2Out"
);
auto
dropout2_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout2Out"
);
auto
linear1_weight
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Weight"
);
auto
linear1_weight
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Weight"
);
auto
*
linear1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Bias"
);
auto
*
linear1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Bias"
);
auto
linear2_weight
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Weight"
);
auto
linear2_weight
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Weight"
);
auto
ln1_mean
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Mean"
);
auto
*
ln1_mean
=
auto
ln1_variance
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Variance"
);
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Mean"
)
:
nullptr
;
auto
*
ln1_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Scale"
);
auto
*
ln1_variance
=
pre_layer_norm
auto
*
ln1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Bias"
);
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Variance"
)
auto
ln2_mean
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Mean"
);
:
nullptr
;
auto
ln2_variance
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Variance"
);
auto
*
ln1_scale
=
auto
*
ln2_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Scale"
);
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Scale"
)
:
nullptr
;
auto
*
ln2_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Bias"
);
auto
*
ln1_bias
=
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Bias"
)
:
nullptr
;
auto
*
ln2_mean
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Mean"
)
:
nullptr
;
auto
*
ln2_variance
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Variance"
)
:
nullptr
;
auto
*
ln2_scale
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Scale"
)
:
nullptr
;
auto
*
ln2_bias
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Bias"
)
:
nullptr
;
auto
*
d_x
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_x
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_ln1_scale
=
auto
*
d_ln1_scale
=
pre_layer_norm
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln1Scale"
));
?
context
.
Output
<
framework
::
Tensor
>
(
auto
*
d_ln1_bias
=
framework
::
GradVarName
(
"Ln1Scale"
))
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln1Bias"
));
:
nullptr
;
auto
*
d_ln1_bias
=
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln1Bias"
))
:
nullptr
;
auto
*
d_ln2_scale
=
auto
*
d_ln2_scale
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln2Scale"
));
pre_layer_norm
?
nullptr
:
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln2Scale"
));
auto
*
d_ln2_bias
=
auto
*
d_ln2_bias
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
pre_layer_norm
?
nullptr
:
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
auto
*
d_linear1_weight
=
context
.
Output
<
framework
::
Tensor
>
(
auto
*
d_linear1_weight
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Linear1Weight"
));
framework
::
GradVarName
(
"Linear1Weight"
));
auto
*
d_linear1_bias
=
context
.
Output
<
framework
::
Tensor
>
(
auto
*
d_linear1_bias
=
context
.
Output
<
framework
::
Tensor
>
(
...
@@ -328,7 +365,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -328,7 +365,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const
float
epsilon1
=
context
.
Attr
<
float
>
(
"ln1_epsilon"
);
const
float
epsilon1
=
context
.
Attr
<
float
>
(
"ln1_epsilon"
);
const
float
epsilon2
=
context
.
Attr
<
float
>
(
"ln2_epsilon"
);
const
float
epsilon2
=
context
.
Attr
<
float
>
(
"ln2_epsilon"
);
const
bool
pre_layer_norm
=
context
.
Attr
<
bool
>
(
"pre_layer_norm"
);
const
std
::
string
act_method
=
context
.
Attr
<
std
::
string
>
(
"act_method"
);
const
std
::
string
act_method
=
context
.
Attr
<
std
::
string
>
(
"act_method"
);
DropoutParam
dropout_param1
(
context
,
1
);
DropoutParam
dropout_param1
(
context
,
1
);
DropoutParam
dropout_param2
(
context
,
2
);
DropoutParam
dropout_param2
(
context
,
2
);
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
浏览文件 @
ff3018d7
...
@@ -65,7 +65,7 @@ class TestFusedAttentionOp(OpTest):
...
@@ -65,7 +65,7 @@ class TestFusedAttentionOp(OpTest):
def
config
(
self
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
Tru
e
self
.
pre_layer_norm
=
Fals
e
self
.
training
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
batch_size
=
8
...
@@ -213,11 +213,40 @@ class TestFusedAttentionOp(OpTest):
...
@@ -213,11 +213,40 @@ class TestFusedAttentionOp(OpTest):
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-5
)
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-5
)
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
def
config
(
self
):
def
config
(
self
):
self
.
x_type
=
np
.
float16
self
.
x_type
=
np
.
float16
self
.
attn_mask_type
=
np
.
float64
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
Tru
e
self
.
pre_layer_norm
=
Fals
e
self
.
training
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
batch_size
=
8
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录