Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ae592233
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ae592233
编写于
10月 28, 2021
作者:
L
Li Min
提交者:
GitHub
10月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix fused_attention_op and fused_feedforward_op bug when pre_layer_norm is false. (#36793) (#36816)
* Fix bug when pre_layer_norm is false.
上级
11b9f5f9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
254 addition
and
144 deletion
+254
-144
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+64
-37
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+25
-19
paddle/fluid/operators/fused/fused_feedforward_op.cc
paddle/fluid/operators/fused/fused_feedforward_op.cc
+56
-44
paddle/fluid/operators/fused/fused_feedforward_op.cu
paddle/fluid/operators/fused/fused_feedforward_op.cu
+78
-42
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
...n/paddle/fluid/tests/unittests/test_fused_attention_op.py
+31
-2
未找到文件。
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
ae592233
...
@@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnMean"
),
"Output"
,
"LnMean"
,
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnMean"
),
"Output"
,
"LnMean"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnVariance"
),
"Output"
,
"LnVariance"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnVariance"
),
"Output"
,
"LnVariance"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnOut"
),
"Output"
,
"LnOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnOut"
),
"Output"
,
"LnOut"
,
"FusedAttentionOp"
);
}
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKVOut"
),
"Output"
,
"QKVOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKVOut"
),
"Output"
,
"QKVOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
...
@@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]"
,
"input qkv_weight = [%s]"
,
x_dim
,
y_dim
));
x_dim
,
y_dim
));
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
"LnVariance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnOut"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"LnVariance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnOut"
,
ctx
->
GetInputDim
(
"X"
));
}
// [batch_size, seq_len, 3, num_head, head_size]
// [batch_size, seq_len, 3, num_head, head_size]
ctx
->
SetOutputDim
(
"QKVOut"
,
ctx
->
SetOutputDim
(
"QKVOut"
,
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
0
],
y_dim
[
1
],
y_dim
[
2
]});
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
0
],
y_dim
[
1
],
y_dim
[
2
]});
...
@@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"Ln2Bias"
));
ctx
->
GetInputDim
(
"Ln2Bias"
));
}
}
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnMean"
),
"Input"
,
"LnMean"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnVariance"
),
"Input"
,
"LnVariance"
,
"FusedAttentionGrad"
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnMean"
),
"Input"
,
"LnMean"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnVariance"
),
"Input"
,
"LnVariance"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnOut"
),
"Input"
,
"LnOut"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"LnOut"
),
"Input"
,
"LnOut"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
}
}
...
@@ -370,13 +375,15 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -370,13 +375,15 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"LnScale"
)))
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnScale"
),
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"LnScale"
)))
{
ctx
->
GetInputDim
(
"LnScale"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnScale"
),
}
ctx
->
GetInputDim
(
"LnScale"
));
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"LnBias"
)))
{
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnBias"
),
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"LnBias"
)))
{
ctx
->
GetInputDim
(
"LnBias"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnBias"
),
ctx
->
GetInputDim
(
"LnBias"
));
}
}
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
...
@@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBias"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBias"
),
ctx
->
GetInputDim
(
"QKVBias"
));
ctx
->
GetInputDim
(
"QKVBias"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnOut"
),
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
GetInputDim
(
"LnOut"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnOut"
),
ctx
->
GetInputDim
(
"LnOut"
));
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"FMHAOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"FMHAOut"
),
ctx
->
GetInputDim
(
"FMHAOut"
));
ctx
->
GetInputDim
(
"FMHAOut"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKTVOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKTVOut"
),
...
@@ -442,16 +451,23 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -442,16 +451,23 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"SrcMask"
,
this
->
Input
(
"SrcMask"
));
op
->
SetInput
(
"SrcMask"
,
this
->
Input
(
"SrcMask"
));
op
->
SetInput
(
"OutLinearW"
,
this
->
Input
(
"OutLinearW"
));
op
->
SetInput
(
"OutLinearW"
,
this
->
Input
(
"OutLinearW"
));
op
->
SetInput
(
"OutLinearBias"
,
this
->
Input
(
"OutLinearBias"
));
op
->
SetInput
(
"OutLinearBias"
,
this
->
Input
(
"OutLinearBias"
));
if
(
this
->
HasInput
(
"LnScale"
))
{
op
->
SetInput
(
"LnScale"
,
this
->
Input
(
"LnScale"
));
op
->
SetAttrMap
(
this
->
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"LnScale"
),
bool
is_pre_layer_norm
=
this
->
InputGrad
(
"LnScale"
));
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"pre_layer_norm"
));
}
if
(
is_pre_layer_norm
)
{
if
(
this
->
HasInput
(
"LnBias"
))
{
if
(
this
->
HasInput
(
"LnScale"
))
{
op
->
SetInput
(
"LnBias"
,
this
->
Input
(
"LnBias"
));
op
->
SetInput
(
"LnScale"
,
this
->
Input
(
"LnScale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"LnBias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"LnScale"
),
this
->
InputGrad
(
"LnBias"
));
this
->
InputGrad
(
"LnScale"
));
}
if
(
this
->
HasInput
(
"LnBias"
))
{
op
->
SetInput
(
"LnBias"
,
this
->
Input
(
"LnBias"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"LnBias"
),
this
->
InputGrad
(
"LnBias"
));
}
}
}
if
(
this
->
HasInput
(
"Ln2Scale"
))
{
if
(
this
->
HasInput
(
"Ln2Scale"
))
{
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
...
@@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this
->
InputGrad
(
"OutLinearW"
));
this
->
InputGrad
(
"OutLinearW"
));
// use forward outputs as backward inputs.
// use forward outputs as backward inputs.
op
->
SetInput
(
"LnOut"
,
this
->
Output
(
"LnOut"
));
if
(
is_pre_layer_norm
)
{
op
->
SetInput
(
"LnMean"
,
this
->
Output
(
"LnMean"
));
if
(
this
->
HasOutput
(
"LnOut"
))
{
op
->
SetInput
(
"LnVariance"
,
this
->
Output
(
"LnVariance"
));
op
->
SetInput
(
"LnOut"
,
this
->
Output
(
"LnOut"
));
}
if
(
this
->
HasOutput
(
"LnMean"
))
{
op
->
SetInput
(
"LnMean"
,
this
->
Output
(
"LnMean"
));
}
if
(
this
->
HasOutput
(
"LnVariance"
))
{
op
->
SetInput
(
"LnVariance"
,
this
->
Output
(
"LnVariance"
));
}
}
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVBiasOut"
,
this
->
Output
(
"QKVBiasOut"
));
op
->
SetInput
(
"QKVBiasOut"
,
this
->
Output
(
"QKVBiasOut"
));
op
->
SetInput
(
"TransposeOut2"
,
this
->
Output
(
"TransposeOut2"
));
op
->
SetInput
(
"TransposeOut2"
,
this
->
Output
(
"TransposeOut2"
));
...
@@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
// backward outputs: dinput
// backward outputs: dinput
op
->
SetOutput
(
framework
::
GradVarName
(
"LnOut"
),
this
->
OutputGrad
(
"LnOut"
));
if
(
is_pre_layer_norm
)
{
if
(
this
->
HasOutput
(
"LnOut"
))
{
op
->
SetOutput
(
framework
::
GradVarName
(
"LnOut"
),
this
->
OutputGrad
(
"LnOut"
));
}
}
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVOut"
),
this
->
OutputGrad
(
"QKVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVOut"
),
this
->
OutputGrad
(
"QKVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
),
this
->
OutputGrad
(
"QKVBiasOut"
));
this
->
OutputGrad
(
"QKVBiasOut"
));
...
@@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this
->
OutputGrad
(
"BiasDropoutResidualOut"
));
this
->
OutputGrad
(
"BiasDropoutResidualOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearOut"
),
this
->
OutputGrad
(
"OutLinearOut"
));
this
->
OutputGrad
(
"OutLinearOut"
));
op
->
SetAttrMap
(
this
->
Attrs
());
}
}
};
};
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
ae592233
...
@@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
*
x_data
=
input_x
->
data
<
T
>
();
auto
*
x_data
=
input_x
->
data
<
T
>
();
auto
*
ln_scale_data
=
(
ln_scale
==
nullptr
?
nullptr
:
ln_scale
->
data
<
U
>
());
auto
*
ln_scale_data
=
(
ln_scale
==
nullptr
?
nullptr
:
ln_scale
->
data
<
U
>
());
auto
*
ln_bias_data
=
(
ln_bias
==
nullptr
?
nullptr
:
ln_bias
->
data
<
U
>
());
auto
*
ln_bias_data
=
(
ln_bias
==
nullptr
?
nullptr
:
ln_bias
->
data
<
U
>
());
auto
*
ln_mean_data
=
ln_mean
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
ln_mean_data
=
auto
*
ln_var_data
=
ln_var
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
pre_layer_norm
?
ln_mean
->
mutable_data
<
U
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
ln_out_data
=
ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
ln_var_data
=
pre_layer_norm
?
ln_var
->
mutable_data
<
U
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
ln_out_data
=
pre_layer_norm
?
ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())
:
nullptr
;
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_bias_data
=
qkv_bias
->
data
<
T
>
();
auto
*
qkv_bias_data
=
qkv_bias
->
data
<
T
>
();
...
@@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
out_linear_bias_data
=
out_linear_bias
->
data
<
T
>
();
auto
*
out_linear_bias_data
=
out_linear_bias
->
data
<
T
>
();
// fw output
// fw output
auto
*
ln_mean
=
ctx
.
Input
<
Tensor
>
(
"LnMean"
);
auto
*
ln_var
=
ctx
.
Input
<
Tensor
>
(
"LnVariance"
);
auto
*
ln_out
=
ctx
.
Input
<
Tensor
>
(
"LnOut"
);
auto
*
fmha_out
=
ctx
.
Input
<
Tensor
>
(
"FMHAOut"
);
auto
*
fmha_out
=
ctx
.
Input
<
Tensor
>
(
"FMHAOut"
);
auto
*
transpose_out_2
=
ctx
.
Input
<
Tensor
>
(
"TransposeOut2"
);
auto
*
transpose_out_2
=
ctx
.
Input
<
Tensor
>
(
"TransposeOut2"
);
auto
*
qk_out
=
ctx
.
Input
<
Tensor
>
(
"QKOut"
);
auto
*
qk_out
=
ctx
.
Input
<
Tensor
>
(
"QKOut"
);
...
@@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
dropout_mask_out
=
ctx
.
Input
<
Tensor
>
(
"DropoutMaskOut"
);
auto
*
dropout_mask_out
=
ctx
.
Input
<
Tensor
>
(
"DropoutMaskOut"
);
auto
*
bias_dropout_residual_out
=
auto
*
bias_dropout_residual_out
=
ctx
.
Input
<
Tensor
>
(
"BiasDropoutResidualOut"
);
ctx
.
Input
<
Tensor
>
(
"BiasDropoutResidualOut"
);
auto
*
ln_mean_data
=
ln_mean
->
data
<
U
>
();
auto
*
ln_var_data
=
ln_var
->
data
<
U
>
();
auto
*
ln_out_data
=
ln_out
->
data
<
T
>
();
auto
*
fmha_out_data
=
fmha_out
->
data
<
T
>
();
auto
*
fmha_out_data
=
fmha_out
->
data
<
T
>
();
auto
*
transpose_out_2_data
=
transpose_out_2
->
data
<
T
>
();
auto
*
transpose_out_2_data
=
transpose_out_2
->
data
<
T
>
();
auto
*
qk_out_data
=
qk_out
->
data
<
T
>
();
auto
*
qk_out_data
=
qk_out
->
data
<
T
>
();
...
@@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
// output's grad
// output's grad
auto
*
d_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_ln_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnOut"
));
auto
*
d_qkv_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVOut"
));
auto
*
d_qkv_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVOut"
));
auto
*
d_qkv_bias_out
=
auto
*
d_qkv_bias_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVBiasOut"
));
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVBiasOut"
));
...
@@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
d_bias_dropout_residual_out
=
auto
*
d_bias_dropout_residual_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
));
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
));
auto
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_ln_out_data
=
d_ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_out_data
=
d_qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_out_data
=
d_qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_out_data
=
d_qkv_bias_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_out_data
=
d_qkv_bias_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qktv_out_data
=
d_qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qktv_out_data
=
d_qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_bias_dropout_residual_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_bias_dropout_residual_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// parameter grad
// parameter grad
auto
*
d_ln_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnScale"
));
auto
*
d_ln_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnBias"
));
auto
*
d_qkv_weight
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVW"
));
auto
*
d_qkv_weight
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVW"
));
auto
*
d_qkv_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVBias"
));
auto
*
d_qkv_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"QKVBias"
));
auto
*
d_out_linear_weight
=
auto
*
d_out_linear_weight
=
...
@@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"OutLinearBias"
));
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"OutLinearBias"
));
auto
*
d_ln_2_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Scale"
));
auto
*
d_ln_2_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Scale"
));
auto
*
d_ln_2_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
auto
*
d_ln_2_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
auto
*
d_ln_scale_data
=
(
d_ln_scale
==
nullptr
?
nullptr
:
d_ln_scale
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_ln_bias_data
=
(
d_ln_bias
==
nullptr
?
nullptr
:
d_ln_bias
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_qkv_weight_data
=
d_qkv_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_weight_data
=
d_qkv_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_data
=
d_qkv_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_data
=
d_qkv_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_out_linear_weight_data
=
auto
*
d_out_linear_weight_data
=
...
@@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
cudaMemcpyDeviceToDevice
);
cudaMemcpyDeviceToDevice
);
if
(
pre_layer_norm
)
{
if
(
pre_layer_norm
)
{
auto
*
ln_mean
=
ctx
.
Input
<
Tensor
>
(
"LnMean"
);
auto
*
ln_var
=
ctx
.
Input
<
Tensor
>
(
"LnVariance"
);
auto
*
ln_out
=
ctx
.
Input
<
Tensor
>
(
"LnOut"
);
auto
*
ln_mean_data
=
ln_mean
->
data
<
U
>
();
auto
*
ln_var_data
=
ln_var
->
data
<
U
>
();
auto
*
ln_out_data
=
ln_out
->
data
<
T
>
();
auto
*
d_ln_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnOut"
));
auto
*
d_ln_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnScale"
));
auto
*
d_ln_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"LnBias"
));
auto
*
d_ln_out_data
=
d_ln_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_ln_scale_data
=
(
d_ln_scale
==
nullptr
?
nullptr
:
d_ln_scale
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
auto
*
d_ln_bias_data
=
(
d_ln_bias
==
nullptr
?
nullptr
:
d_ln_bias
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
qkv_compute
.
ComputeBackward
(
ln_out_data
,
qkv_weight_data
,
qkv_compute
.
ComputeBackward
(
ln_out_data
,
qkv_weight_data
,
d_qkv_bias_out_data
,
d_ln_out_data
,
d_qkv_bias_out_data
,
d_ln_out_data
,
d_qkv_weight_data
,
d_qkv_bias_data
);
d_qkv_weight_data
,
d_qkv_bias_data
);
...
...
paddle/fluid/operators/fused/fused_feedforward_op.cc
浏览文件 @
ae592233
...
@@ -41,18 +41,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
...
@@ -41,18 +41,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
"fused_feedforward"
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout2Mask"
),
"Output"
,
"Dropout2Mask"
,
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout2Mask"
),
"Output"
,
"Dropout2Mask"
,
"fused_feedforward"
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Mean"
),
"Output"
,
"Ln1Mean"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Variance"
),
"Output"
,
"Ln1Variance"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln2Mean"
),
"Output"
,
"Ln2Mean"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln2Variance"
),
"Output"
,
"Ln2Variance"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Linear1Out"
),
"Output"
,
"Linear1Out"
,
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Linear1Out"
),
"Output"
,
"Linear1Out"
,
"fused_feedforward"
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Out"
),
"Output"
,
"Ln1Out"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout1Out"
),
"Output"
,
"Dropout1Out"
,
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout1Out"
),
"Output"
,
"Dropout1Out"
,
"fused_feedforward"
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout2Out"
),
"Output"
,
"Dropout2Out"
,
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Dropout2Out"
),
"Output"
,
"Dropout2Out"
,
...
@@ -76,7 +66,6 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
...
@@ -76,7 +66,6 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
}
}
context
->
SetOutputDim
(
"Dropout1Out"
,
tmp_dim_x
);
context
->
SetOutputDim
(
"Dropout1Out"
,
tmp_dim_x
);
context
->
SetOutputDim
(
"Linear1Out"
,
tmp_dim_x
);
context
->
SetOutputDim
(
"Linear1Out"
,
tmp_dim_x
);
context
->
SetOutputDim
(
"Ln1Out"
,
dim_x
);
context
->
SetOutputDim
(
"Dropout2Out"
,
dim_x
);
context
->
SetOutputDim
(
"Dropout2Out"
,
dim_x
);
if
(
context
->
Attrs
().
Get
<
bool
>
(
"dropout2_is_test"
)
==
false
)
{
if
(
context
->
Attrs
().
Get
<
bool
>
(
"dropout2_is_test"
)
==
false
)
{
...
@@ -84,10 +73,25 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
...
@@ -84,10 +73,25 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
}
}
framework
::
DDim
mean_dim
=
framework
::
DDim
mean_dim
=
framework
::
make_ddim
({
mat_dim_x
.
batch_size_
*
mat_dim_x
.
height_
});
framework
::
make_ddim
({
mat_dim_x
.
batch_size_
*
mat_dim_x
.
height_
});
context
->
SetOutputDim
(
"Ln1Mean"
,
mean_dim
);
bool
pre_layer_norm
=
context
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
);
context
->
SetOutputDim
(
"Ln1Variance"
,
mean_dim
);
if
(
pre_layer_norm
)
{
context
->
SetOutputDim
(
"Ln2Mean"
,
mean_dim
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Mean"
),
"Output"
,
"Ln1Mean"
,
context
->
SetOutputDim
(
"Ln2Variance"
,
mean_dim
);
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Variance"
),
"Output"
,
"Ln1Variance"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln1Out"
),
"Output"
,
"Ln1Out"
,
"fused_feedforward"
);
context
->
SetOutputDim
(
"Ln1Out"
,
dim_x
);
context
->
SetOutputDim
(
"Ln1Mean"
,
mean_dim
);
context
->
SetOutputDim
(
"Ln1Variance"
,
mean_dim
);
}
else
{
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln2Mean"
),
"Output"
,
"Ln2Mean"
,
"fused_feedforward"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Ln2Variance"
),
"Output"
,
"Ln2Variance"
,
"fused_feedforward"
);
context
->
SetOutputDim
(
"Ln2Mean"
,
mean_dim
);
context
->
SetOutputDim
(
"Ln2Variance"
,
mean_dim
);
}
context
->
ShareLoD
(
"X"
,
"Out"
);
context
->
ShareLoD
(
"X"
,
"Out"
);
}
}
...
@@ -218,14 +222,13 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
...
@@ -218,14 +222,13 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"dropout2_is_test"
),
false
,
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"dropout2_is_test"
),
false
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"GradOp is only callable when is_test is false"
));
"GradOp is only callable when is_test is false"
));
bool
pre_layer_norm
=
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout1Mask"
),
"Input"
,
"Dropout1Mask"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout1Mask"
),
"Input"
,
"Dropout1Mask"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout2Mask"
),
"Input"
,
"Dropout1Mask"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout2Mask"
),
"Input"
,
"Dropout1Mask"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Linear1Out"
),
"Input"
,
"Linear1Out"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Linear1Out"
),
"Input"
,
"Linear1Out"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Out"
),
"Input"
,
"Ln1Out"
,
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout1Out"
),
"Input"
,
"Dropout1Out"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout1Out"
),
"Input"
,
"Dropout1Out"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout2Out"
),
"Input"
,
"Dropout2Out"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Dropout2Out"
),
"Input"
,
"Dropout2Out"
,
...
@@ -234,14 +237,19 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
...
@@ -234,14 +237,19 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Linear2Weight"
),
"Input"
,
"Linear2Weight"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Linear2Weight"
),
"Input"
,
"Linear2Weight"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Mean"
),
"Input"
,
"Ln1Mean"
,
if
(
pre_layer_norm
)
{
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Mean"
),
"Input"
,
"Ln1Mean"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Variance"
),
"Input"
,
"Ln1Variance"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Variance"
),
"Input"
,
"Ln1Variance"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Mean"
),
"Input"
,
"Ln2Mean"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln1Out"
),
"Input"
,
"Ln1Out"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Variance"
),
"Input"
,
"Ln2Variance"
,
"FusedFeedForwardGrad"
);
"FusedFeedForwardGrad"
);
}
else
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Mean"
),
"Input"
,
"Ln2Mean"
,
"FusedFeedForwardGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ln2Variance"
),
"Input"
,
"Ln2Variance"
,
"FusedFeedForwardGrad"
);
}
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"FusedFeedForwardGrad"
);
framework
::
GradVarName
(
"Out"
),
"FusedFeedForwardGrad"
);
...
@@ -299,30 +307,36 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -299,30 +307,36 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op
->
SetInput
(
"Linear1Weight"
,
this
->
Input
(
"Linear1Weight"
));
op
->
SetInput
(
"Linear1Weight"
,
this
->
Input
(
"Linear1Weight"
));
op
->
SetInput
(
"Linear1Bias"
,
this
->
Input
(
"Linear1Bias"
));
op
->
SetInput
(
"Linear1Bias"
,
this
->
Input
(
"Linear1Bias"
));
op
->
SetInput
(
"Linear2Weight"
,
this
->
Input
(
"Linear2Weight"
));
op
->
SetInput
(
"Linear2Weight"
,
this
->
Input
(
"Linear2Weight"
));
op
->
SetInput
(
"Ln1Scale"
,
this
->
Input
(
"Ln1Scale"
));
op
->
SetInput
(
"Ln1Bias"
,
this
->
Input
(
"Ln1Bias"
));
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetInput
(
"Ln2Bias"
,
this
->
Input
(
"Ln2Bias"
));
op
->
SetInput
(
"Dropout1Mask"
,
this
->
Output
(
"Dropout1Mask"
));
op
->
SetInput
(
"Dropout1Mask"
,
this
->
Output
(
"Dropout1Mask"
));
op
->
SetInput
(
"Dropout2Mask"
,
this
->
Output
(
"Dropout2Mask"
));
op
->
SetInput
(
"Dropout2Mask"
,
this
->
Output
(
"Dropout2Mask"
));
op
->
SetInput
(
"Linear1Out"
,
this
->
Output
(
"Linear1Out"
));
op
->
SetInput
(
"Linear1Out"
,
this
->
Output
(
"Linear1Out"
));
op
->
SetInput
(
"Ln1Out"
,
this
->
Output
(
"Ln1Out"
));
op
->
SetInput
(
"Ln1Mean"
,
this
->
Output
(
"Ln1Mean"
));
op
->
SetInput
(
"Ln1Variance"
,
this
->
Output
(
"Ln1Variance"
));
op
->
SetInput
(
"Ln2Mean"
,
this
->
Output
(
"Ln2Mean"
));
op
->
SetInput
(
"Ln2Variance"
,
this
->
Output
(
"Ln2Variance"
));
op
->
SetInput
(
"Dropout1Out"
,
this
->
Output
(
"Dropout1Out"
));
op
->
SetInput
(
"Dropout1Out"
,
this
->
Output
(
"Dropout1Out"
));
op
->
SetInput
(
"Dropout2Out"
,
this
->
Output
(
"Dropout2Out"
));
op
->
SetInput
(
"Dropout2Out"
,
this
->
Output
(
"Dropout2Out"
));
op
->
SetAttrMap
(
this
->
Attrs
());
bool
pre_layer_norm
=
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"pre_layer_norm"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln1Scale"
),
if
(
pre_layer_norm
)
{
this
->
InputGrad
(
"Ln1Scale"
));
op
->
SetInput
(
"Ln1Scale"
,
this
->
Input
(
"Ln1Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln1Bias"
),
op
->
SetInput
(
"Ln1Bias"
,
this
->
Input
(
"Ln1Bias"
));
this
->
InputGrad
(
"Ln1Bias"
));
op
->
SetInput
(
"Ln1Out"
,
this
->
Output
(
"Ln1Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
op
->
SetInput
(
"Ln1Mean"
,
this
->
Output
(
"Ln1Mean"
));
this
->
InputGrad
(
"Ln2Scale"
));
op
->
SetInput
(
"Ln1Variance"
,
this
->
Output
(
"Ln1Variance"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Bias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln1Scale"
),
this
->
InputGrad
(
"Ln2Bias"
));
this
->
InputGrad
(
"Ln1Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln1Bias"
),
this
->
InputGrad
(
"Ln1Bias"
));
}
else
{
op
->
SetInput
(
"Ln2Scale"
,
this
->
Input
(
"Ln2Scale"
));
op
->
SetInput
(
"Ln2Bias"
,
this
->
Input
(
"Ln2Bias"
));
op
->
SetInput
(
"Ln2Mean"
,
this
->
Output
(
"Ln2Mean"
));
op
->
SetInput
(
"Ln2Variance"
,
this
->
Output
(
"Ln2Variance"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Scale"
),
this
->
InputGrad
(
"Ln2Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Ln2Bias"
),
this
->
InputGrad
(
"Ln2Bias"
));
}
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear1Weight"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear1Weight"
),
this
->
InputGrad
(
"Linear1Weight"
));
this
->
InputGrad
(
"Linear1Weight"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear1Bias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear1Bias"
),
...
@@ -334,8 +348,6 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -334,8 +348,6 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear2Bias"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"Linear2Bias"
),
this
->
InputGrad
(
"Linear2Bias"
));
this
->
InputGrad
(
"Linear2Bias"
));
}
}
op
->
SetAttrMap
(
this
->
Attrs
());
}
}
};
};
...
...
paddle/fluid/operators/fused/fused_feedforward_op.cu
浏览文件 @
ae592233
...
@@ -113,26 +113,40 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
...
@@ -113,26 +113,40 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
auto
*
linear1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Bias"
);
auto
*
linear1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Bias"
);
auto
*
linear2_weight
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Weight"
);
auto
*
linear2_weight
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Weight"
);
auto
*
linear2_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Bias"
);
auto
*
linear2_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Bias"
);
auto
*
ln1_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Scale"
);
const
bool
pre_layer_norm
=
context
.
Attr
<
bool
>
(
"pre_layer_norm"
);
auto
*
ln1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Bias"
);
auto
*
ln2_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Scale"
);
auto
*
ln1_scale
=
auto
*
ln2_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Bias"
);
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Scale"
)
:
nullptr
;
auto
*
ln1_bias
=
auto
*
ln1_mean
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Mean"
);
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Bias"
)
:
nullptr
;
auto
*
ln1_variance
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Variance"
);
auto
*
ln2_scale
=
!
pre_layer_norm
auto
*
ln2_mean
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln2Mean"
);
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Scale"
)
auto
*
ln2_variance
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln2Variance"
);
:
nullptr
;
auto
*
ln2_bias
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Bias"
)
:
nullptr
;
auto
*
ln1_mean
=
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Mean"
)
:
nullptr
;
auto
*
ln1_variance
=
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Variance"
)
:
nullptr
;
auto
*
ln2_mean
=
!
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln2Mean"
)
:
nullptr
;
auto
*
ln2_variance
=
!
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln2Variance"
)
:
nullptr
;
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
dropout1_mask
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout1Mask"
);
auto
*
dropout1_mask
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout1Mask"
);
auto
*
dropout2_mask
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout2Mask"
);
auto
*
dropout2_mask
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout2Mask"
);
auto
*
linear1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Linear1Out"
);
auto
*
linear1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Linear1Out"
);
auto
*
ln1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Out"
);
auto
*
ln1_out
=
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
"Ln1Out"
)
:
nullptr
;
auto
*
dropout1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout1Out"
);
auto
*
dropout1_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout1Out"
);
auto
*
dropout2_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout2Out"
);
auto
*
dropout2_out
=
context
.
Output
<
framework
::
Tensor
>
(
"Dropout2Out"
);
const
std
::
string
act_method
=
context
.
Attr
<
std
::
string
>
(
"act_method"
);
const
std
::
string
act_method
=
context
.
Attr
<
std
::
string
>
(
"act_method"
);
const
bool
pre_layer_norm
=
context
.
Attr
<
bool
>
(
"pre_layer_norm"
);
const
float
epsilon1
=
context
.
Attr
<
float
>
(
"ln1_epsilon"
);
const
float
epsilon1
=
context
.
Attr
<
float
>
(
"ln1_epsilon"
);
const
float
epsilon2
=
context
.
Attr
<
float
>
(
"ln2_epsilon"
);
const
float
epsilon2
=
context
.
Attr
<
float
>
(
"ln2_epsilon"
);
...
@@ -144,12 +158,16 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
...
@@ -144,12 +158,16 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
out
->
mutable_data
<
T
>
(
place
);
out
->
mutable_data
<
T
>
(
place
);
dropout1_mask
->
mutable_data
<
uint8_t
>
(
place
);
dropout1_mask
->
mutable_data
<
uint8_t
>
(
place
);
dropout2_mask
->
mutable_data
<
uint8_t
>
(
place
);
dropout2_mask
->
mutable_data
<
uint8_t
>
(
place
);
ln1_mean
->
mutable_data
<
U
>
(
place
);
if
(
pre_layer_norm
)
{
ln1_variance
->
mutable_data
<
U
>
(
place
);
ln1_mean
->
mutable_data
<
U
>
(
place
);
ln2_mean
->
mutable_data
<
U
>
(
place
);
ln1_variance
->
mutable_data
<
U
>
(
place
);
ln2_variance
->
mutable_data
<
U
>
(
place
);
ln1_out
->
mutable_data
<
T
>
(
place
);
}
else
{
ln2_mean
->
mutable_data
<
U
>
(
place
);
ln2_variance
->
mutable_data
<
U
>
(
place
);
}
linear1_out
->
mutable_data
<
T
>
(
place
);
linear1_out
->
mutable_data
<
T
>
(
place
);
ln1_out
->
mutable_data
<
T
>
(
place
);
dropout1_out
->
mutable_data
<
T
>
(
place
);
dropout1_out
->
mutable_data
<
T
>
(
place
);
dropout2_out
->
mutable_data
<
T
>
(
place
);
dropout2_out
->
mutable_data
<
T
>
(
place
);
...
@@ -193,16 +211,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -193,16 +211,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const
framework
::
Tensor
&
d_out
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
d_out
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
dropout1_mask
,
const
framework
::
Tensor
&
dropout1_mask
,
const
framework
::
Tensor
&
dropout2_mask
,
const
framework
::
Tensor
&
dropout2_mask
,
const
framework
::
Tensor
&
linear1_out
,
const
framework
::
Tensor
&
ln1_out
,
const
framework
::
Tensor
&
linear1_out
,
const
framework
::
Tensor
*
ln1_out
,
const
framework
::
Tensor
&
dropout1_out
,
const
framework
::
Tensor
&
dropout1_out
,
const
framework
::
Tensor
&
dropout2_out
,
const
framework
::
Tensor
&
dropout2_out
,
const
framework
::
Tensor
&
linear1_weight
,
const
framework
::
Tensor
&
linear1_weight
,
const
framework
::
Tensor
*
linear1_bias
,
const
framework
::
Tensor
*
linear1_bias
,
const
framework
::
Tensor
&
linear2_weight
,
const
framework
::
Tensor
&
linear2_weight
,
const
framework
::
Tensor
*
ln1_gamma
,
const
framework
::
Tensor
*
ln1_beta
,
const
framework
::
Tensor
*
ln1_gamma
,
const
framework
::
Tensor
*
ln1_beta
,
const
framework
::
Tensor
&
ln1_mean
,
const
framework
::
Tensor
&
ln1_variance
,
const
framework
::
Tensor
*
ln1_mean
,
const
framework
::
Tensor
*
ln1_variance
,
const
framework
::
Tensor
*
ln2_gamma
,
const
framework
::
Tensor
*
ln2_beta
,
const
framework
::
Tensor
*
ln2_gamma
,
const
framework
::
Tensor
*
ln2_beta
,
const
framework
::
Tensor
&
ln2_mean
,
const
framework
::
Tensor
&
ln2_variance
,
const
framework
::
Tensor
*
ln2_mean
,
const
framework
::
Tensor
*
ln2_variance
,
framework
::
Tensor
*
d_x
,
framework
::
Tensor
*
d_linear1_weight
,
framework
::
Tensor
*
d_x
,
framework
::
Tensor
*
d_linear1_weight
,
framework
::
Tensor
*
d_linear1_bias
,
framework
::
Tensor
*
d_linear2_weight
,
framework
::
Tensor
*
d_linear1_bias
,
framework
::
Tensor
*
d_linear2_weight
,
framework
::
Tensor
*
d_linear2_bias
,
framework
::
Tensor
*
d_ln1_gamma
,
framework
::
Tensor
*
d_linear2_bias
,
framework
::
Tensor
*
d_ln1_gamma
,
...
@@ -252,8 +270,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -252,8 +270,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
}
else
{
}
else
{
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBiasGrad
(
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBiasGrad
(
ctx
,
d_out
.
data
<
T
>
(),
dropout2_out
.
data
<
T
>
(),
ctx
,
d_out
.
data
<
T
>
(),
dropout2_out
.
data
<
T
>
(),
dropout2_mask
.
data
<
uint8_t
>
(),
ln2_gamma_ptr
,
ln2_mean
.
data
<
U
>
(),
dropout2_mask
.
data
<
uint8_t
>
(),
ln2_gamma_ptr
,
ln2_mean
->
data
<
U
>
(),
ln2_variance
.
data
<
U
>
(),
d_dropout2_out
.
data
<
T
>
(),
d_ln2_gamma_ptr
,
ln2_variance
->
data
<
U
>
(),
d_dropout2_out
.
data
<
T
>
(),
d_ln2_gamma_ptr
,
d_ln2_beta_ptr
,
d_linear2_out
.
data
<
T
>
(),
d_linear2_bias_ptr
,
d_ln2_beta_ptr
,
d_linear2_out
.
data
<
T
>
(),
d_linear2_bias_ptr
,
d_residual
.
data
<
T
>
());
d_residual
.
data
<
T
>
());
}
}
...
@@ -273,13 +291,13 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -273,13 +291,13 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
if
(
pre_layer_norm
)
{
if
(
pre_layer_norm
)
{
framework
::
Tensor
d_ln1_out
;
framework
::
Tensor
d_ln1_out
;
d_ln1_out
.
mutable_data
<
T
>
({
bsz_seq
,
d_model
},
place
);
d_ln1_out
.
mutable_data
<
T
>
({
bsz_seq
,
d_model
},
place
);
MatMulGrad
(
ctx
,
d_linear1_out
,
ln1_out
,
linear1_weight
,
&
d_ln1_out
,
MatMulGrad
(
ctx
,
d_linear1_out
,
*
ln1_out
,
linear1_weight
,
&
d_ln1_out
,
d_linear1_weight
);
d_linear1_weight
);
pre_layernorm_helper
.
LayerNormGrad
(
ctx
,
d_ln1_out
.
data
<
T
>
(),
x
.
data
<
T
>
(),
pre_layernorm_helper
.
LayerNormGrad
(
ln1_gamma_ptr
,
ln1_mean
.
data
<
U
>
()
,
ctx
,
d_ln1_out
.
data
<
T
>
(),
x
.
data
<
T
>
(),
ln1_gamma_ptr
,
ln1_variance
.
data
<
U
>
(),
d_x
->
data
<
T
>
(),
ln1_mean
->
data
<
U
>
(),
ln1_variance
->
data
<
U
>
(),
d_x
->
data
<
T
>
(),
d_ln1_gamma_ptr
,
d_ln1_beta_ptr
);
d_ln1_gamma_ptr
,
d_ln1_beta_ptr
);
}
else
{
}
else
{
MatMulGrad
(
ctx
,
d_linear1_out
,
x
,
linear1_weight
,
d_x
,
d_linear1_weight
);
MatMulGrad
(
ctx
,
d_linear1_out
,
x
,
linear1_weight
,
d_x
,
d_linear1_weight
);
}
}
...
@@ -290,33 +308,52 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -290,33 +308,52 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
auto
d_out
=
auto
d_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
*
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
x
=
*
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
x
=
*
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
bool
pre_layer_norm
=
context
.
Attr
<
bool
>
(
"pre_layer_norm"
);
auto
dropout1_mask
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout1Mask"
);
auto
dropout1_mask
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout1Mask"
);
auto
dropout2_mask
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout2Mask"
);
auto
dropout2_mask
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout2Mask"
);
auto
linear1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Out"
);
auto
linear1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Out"
);
auto
ln1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Out"
);
auto
*
ln1_out
=
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Out"
)
:
nullptr
;
auto
dropout1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout1Out"
);
auto
dropout1_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout1Out"
);
auto
dropout2_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout2Out"
);
auto
dropout2_out
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Dropout2Out"
);
auto
linear1_weight
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Weight"
);
auto
linear1_weight
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Weight"
);
auto
*
linear1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Bias"
);
auto
*
linear1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Linear1Bias"
);
auto
linear2_weight
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Weight"
);
auto
linear2_weight
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Linear2Weight"
);
auto
ln1_mean
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Mean"
);
auto
*
ln1_mean
=
auto
ln1_variance
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Variance"
);
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Mean"
)
:
nullptr
;
auto
*
ln1_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Scale"
);
auto
*
ln1_variance
=
pre_layer_norm
auto
*
ln1_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Bias"
);
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Variance"
)
auto
ln2_mean
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Mean"
);
:
nullptr
;
auto
ln2_variance
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Variance"
);
auto
*
ln1_scale
=
auto
*
ln2_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Scale"
);
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Scale"
)
:
nullptr
;
auto
*
ln2_bias
=
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Bias"
);
auto
*
ln1_bias
=
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln1Bias"
)
:
nullptr
;
auto
*
ln2_mean
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Mean"
)
:
nullptr
;
auto
*
ln2_variance
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Variance"
)
:
nullptr
;
auto
*
ln2_scale
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Scale"
)
:
nullptr
;
auto
*
ln2_bias
=
!
pre_layer_norm
?
context
.
Input
<
framework
::
Tensor
>
(
"Ln2Bias"
)
:
nullptr
;
auto
*
d_x
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_x
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_ln1_scale
=
auto
*
d_ln1_scale
=
pre_layer_norm
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln1Scale"
));
?
context
.
Output
<
framework
::
Tensor
>
(
auto
*
d_ln1_bias
=
framework
::
GradVarName
(
"Ln1Scale"
))
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln1Bias"
));
:
nullptr
;
auto
*
d_ln1_bias
=
pre_layer_norm
?
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln1Bias"
))
:
nullptr
;
auto
*
d_ln2_scale
=
auto
*
d_ln2_scale
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln2Scale"
));
pre_layer_norm
?
nullptr
:
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln2Scale"
));
auto
*
d_ln2_bias
=
auto
*
d_ln2_bias
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
pre_layer_norm
?
nullptr
:
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
auto
*
d_linear1_weight
=
context
.
Output
<
framework
::
Tensor
>
(
auto
*
d_linear1_weight
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Linear1Weight"
));
framework
::
GradVarName
(
"Linear1Weight"
));
auto
*
d_linear1_bias
=
context
.
Output
<
framework
::
Tensor
>
(
auto
*
d_linear1_bias
=
context
.
Output
<
framework
::
Tensor
>
(
...
@@ -328,7 +365,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
...
@@ -328,7 +365,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const
float
epsilon1
=
context
.
Attr
<
float
>
(
"ln1_epsilon"
);
const
float
epsilon1
=
context
.
Attr
<
float
>
(
"ln1_epsilon"
);
const
float
epsilon2
=
context
.
Attr
<
float
>
(
"ln2_epsilon"
);
const
float
epsilon2
=
context
.
Attr
<
float
>
(
"ln2_epsilon"
);
const
bool
pre_layer_norm
=
context
.
Attr
<
bool
>
(
"pre_layer_norm"
);
const
std
::
string
act_method
=
context
.
Attr
<
std
::
string
>
(
"act_method"
);
const
std
::
string
act_method
=
context
.
Attr
<
std
::
string
>
(
"act_method"
);
DropoutParam
dropout_param1
(
context
,
1
);
DropoutParam
dropout_param1
(
context
,
1
);
DropoutParam
dropout_param2
(
context
,
2
);
DropoutParam
dropout_param2
(
context
,
2
);
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
浏览文件 @
ae592233
...
@@ -65,7 +65,7 @@ class TestFusedAttentionOp(OpTest):
...
@@ -65,7 +65,7 @@ class TestFusedAttentionOp(OpTest):
def
config
(
self
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
Tru
e
self
.
pre_layer_norm
=
Fals
e
self
.
training
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
batch_size
=
8
...
@@ -213,11 +213,40 @@ class TestFusedAttentionOp(OpTest):
...
@@ -213,11 +213,40 @@ class TestFusedAttentionOp(OpTest):
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-5
)
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-5
)
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
None
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-1
)
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpFp16
(
TestFusedAttentionOp
):
def
config
(
self
):
def
config
(
self
):
self
.
x_type
=
np
.
float16
self
.
x_type
=
np
.
float16
self
.
attn_mask_type
=
np
.
float64
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
Tru
e
self
.
pre_layer_norm
=
Fals
e
self
.
training
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
batch_size
=
8
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录