Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1a8786cf
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1a8786cf
编写于
11月 23, 2021
作者:
L
Li Min
提交者:
GitHub
11月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add support bias is none for fused_attention op. (#37411)
Add support for bias is none for fused_attention op.
上级
4812eda5
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
173 addition
and
71 deletion
+173
-71
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+47
-30
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+76
-31
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
...n/paddle/fluid/tests/unittests/test_fused_attention_op.py
+47
-10
python/paddle/incubate/nn/functional/fused_transformer.py
python/paddle/incubate/nn/functional/fused_transformer.py
+3
-0
未找到文件。
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
1a8786cf
...
@@ -28,12 +28,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -28,12 +28,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVW"
),
"Input"
,
"QKVW"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVW"
),
"Input"
,
"QKVW"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVBias"
),
"Input"
,
"QKVBias"
,
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearW"
),
"Input"
,
"OutLinearW"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearW"
),
"Input"
,
"OutLinearW"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
"FusedAttentionOp"
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnMean"
),
"Output"
,
"LnMean"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"LnMean"
),
"Output"
,
"LnMean"
,
...
@@ -54,8 +50,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -54,8 +50,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKVOut"
),
"Output"
,
"QKVOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKVOut"
),
"Output"
,
"QKVOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
if
(
ctx
->
HasInput
(
"QKVBias"
))
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKVBiasOut"
),
"Output"
,
"QKVBiasOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKVBiasOut"
),
"Output"
,
"QKVBiasOut"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
}
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"TransposeOut2"
),
"Output"
,
"TransposeOut2"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"TransposeOut2"
),
"Output"
,
"TransposeOut2"
,
"FusedAttentionOp"
);
"FusedAttentionOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKOut"
),
"Output"
,
"QKOut"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"QKOut"
),
"Output"
,
"QKOut"
,
...
@@ -107,6 +105,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -107,6 +105,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]"
,
"input qkv_weight = [%s]"
,
x_dim
,
y_dim
));
x_dim
,
y_dim
));
PADDLE_ENFORCE_EQ
(
y_dim
[
1
]
*
y_dim
[
2
],
y_dim
[
3
],
platform
::
errors
::
InvalidArgument
(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"
));
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnMean"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnVariance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
ctx
->
SetOutputDim
(
"LnVariance"
,
{
x_dim
[
0
]
*
x_dim
[
1
]});
...
@@ -119,8 +124,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
...
@@ -119,8 +124,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// [batch_size, seq_len, 3, num_head, head_size]
// [batch_size, seq_len, 3, num_head, head_size]
ctx
->
SetOutputDim
(
"QKVOut"
,
ctx
->
SetOutputDim
(
"QKVOut"
,
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
0
],
y_dim
[
1
],
y_dim
[
2
]});
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
0
],
y_dim
[
1
],
y_dim
[
2
]});
if
(
ctx
->
HasInput
(
"QKVBias"
))
{
ctx
->
SetOutputDim
(
"QKVBiasOut"
,
ctx
->
SetOutputDim
(
"QKVBiasOut"
,
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
0
],
y_dim
[
1
],
y_dim
[
2
]});
{
x_dim
[
0
],
x_dim
[
1
],
y_dim
[
0
],
y_dim
[
1
],
y_dim
[
2
]});
}
// [3, batch_size, num_head, seq_len, head_size]
// [3, batch_size, num_head, seq_len, head_size]
ctx
->
SetOutputDim
(
"TransposeOut2"
,
ctx
->
SetOutputDim
(
"TransposeOut2"
,
{
y_dim
[
0
],
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
{
y_dim
[
0
],
x_dim
[
0
],
y_dim
[
1
],
x_dim
[
1
],
y_dim
[
2
]});
...
@@ -173,11 +181,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -173,11 +181,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"H. Here, H represents the last dimension of its input tensor."
)
"H. Here, H represents the last dimension of its input tensor."
)
.
AsDispensable
();
.
AsDispensable
();
AddInput
(
"QKVW"
,
"The qkv weight tensor."
);
AddInput
(
"QKVW"
,
"The qkv weight tensor."
);
AddInput
(
"QKVBias"
,
"The qkv bias tensor."
);
AddInput
(
"QKVBias"
,
"The qkv bias tensor."
)
.
AsDispensable
()
;
AddInput
(
"SrcMask"
,
"(optional) The attention mask tensor in fmha."
)
AddInput
(
"SrcMask"
,
"(optional) The attention mask tensor in fmha."
)
.
AsDispensable
();
.
AsDispensable
();
AddInput
(
"OutLinearW"
,
"The out_linear weight tensor."
);
AddInput
(
"OutLinearW"
,
"The out_linear weight tensor."
);
AddInput
(
"OutLinearBias"
,
"The out_linear bias tensor."
);
AddInput
(
"OutLinearBias"
,
"The out_linear bias tensor."
)
.
AsDispensable
()
;
AddInput
(
"Ln2Scale"
,
AddInput
(
"Ln2Scale"
,
"(optional) Scale is a 1-dimensional tensor of size "
"(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor."
)
"H. Here, H represents the last dimension of its input tensor."
)
...
@@ -379,12 +387,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -379,12 +387,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVW"
),
"Input"
,
"QKVW"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVW"
),
"Input"
,
"QKVW"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"QKVBias"
),
"Input"
,
"QKVBias"
,
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearW"
),
"Input"
,
"OutLinearW"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearW"
),
"Input"
,
"OutLinearW"
,
"FusedAttentionGrad"
);
"FusedAttentionGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"OutLinearBias"
),
"Input"
,
"OutLinearBias"
,
"FusedAttentionGrad"
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"LnScale"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"LnScale"
)))
{
...
@@ -399,14 +403,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -399,14 +403,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"OutLinearBias"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"OutLinearBias"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"OutLinearBias"
),
ctx
->
GetInputDim
(
"OutLinearBias"
));
ctx
->
GetInputDim
(
"OutLinearBias"
));
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"OutLinearW"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"OutLinearW"
),
ctx
->
GetInputDim
(
"OutLinearW"
));
ctx
->
GetInputDim
(
"OutLinearW"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVW"
),
ctx
->
GetInputDim
(
"QKVW"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVW"
),
ctx
->
GetInputDim
(
"QKVW"
));
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"QKVBias"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBias"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBias"
),
ctx
->
GetInputDim
(
"QKVBias"
));
ctx
->
GetInputDim
(
"QKVBias"
));
}
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"pre_layer_norm"
)
==
true
)
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"LnOut"
),
...
@@ -434,8 +441,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
...
@@ -434,8 +441,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
}
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVOut"
),
ctx
->
GetInputDim
(
"QKVOut"
));
ctx
->
GetInputDim
(
"QKVOut"
));
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBiasOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"QKVBiasOut"
),
ctx
->
GetInputDim
(
"QKVBiasOut"
));
ctx
->
GetInputDim
(
"QKVBiasOut"
));
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"OutLinearOut"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"OutLinearOut"
),
ctx
->
GetInputDim
(
"OutLinearOut"
));
ctx
->
GetInputDim
(
"OutLinearOut"
));
}
}
...
@@ -462,7 +471,15 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -462,7 +471,15 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
// inputs x, parameters and their grad.
// inputs x, parameters and their grad.
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"QKVW"
,
this
->
Input
(
"QKVW"
));
op
->
SetInput
(
"QKVW"
,
this
->
Input
(
"QKVW"
));
if
(
this
->
HasInput
(
"QKVBias"
))
{
op
->
SetInput
(
"QKVBias"
,
this
->
Input
(
"QKVBias"
));
op
->
SetInput
(
"QKVBias"
,
this
->
Input
(
"QKVBias"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBias"
),
this
->
InputGrad
(
"QKVBias"
));
op
->
SetInput
(
"QKVBiasOut"
,
this
->
Output
(
"QKVBiasOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
),
this
->
OutputGrad
(
"QKVBiasOut"
));
}
if
(
this
->
HasInput
(
"SrcMask"
))
{
if
(
this
->
HasInput
(
"SrcMask"
))
{
op
->
SetInput
(
"SrcMask"
,
this
->
Input
(
"SrcMask"
));
op
->
SetInput
(
"SrcMask"
,
this
->
Input
(
"SrcMask"
));
...
@@ -472,7 +489,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -472,7 +489,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
}
}
op
->
SetInput
(
"OutLinearW"
,
this
->
Input
(
"OutLinearW"
));
op
->
SetInput
(
"OutLinearW"
,
this
->
Input
(
"OutLinearW"
));
if
(
this
->
HasInput
(
"OutLinearBias"
))
{
op
->
SetInput
(
"OutLinearBias"
,
this
->
Input
(
"OutLinearBias"
));
op
->
SetInput
(
"OutLinearBias"
,
this
->
Input
(
"OutLinearBias"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearBias"
),
this
->
InputGrad
(
"OutLinearBias"
));
}
op
->
SetAttrMap
(
this
->
Attrs
());
op
->
SetAttrMap
(
this
->
Attrs
());
bool
is_pre_layer_norm
=
bool
is_pre_layer_norm
=
...
@@ -503,10 +524,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -503,10 +524,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVW"
),
this
->
InputGrad
(
"QKVW"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVW"
),
this
->
InputGrad
(
"QKVW"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBias"
),
this
->
InputGrad
(
"QKVBias"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearBias"
),
this
->
InputGrad
(
"OutLinearBias"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearW"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"OutLinearW"
),
this
->
InputGrad
(
"OutLinearW"
));
this
->
InputGrad
(
"OutLinearW"
));
...
@@ -528,7 +546,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -528,7 +546,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this
->
Output
(
"BiasDropoutResidualOut"
));
this
->
Output
(
"BiasDropoutResidualOut"
));
}
}
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVOut"
,
this
->
Output
(
"QKVOut"
));
op
->
SetInput
(
"QKVBiasOut"
,
this
->
Output
(
"QKVBiasOut"
));
op
->
SetInput
(
"TransposeOut2"
,
this
->
Output
(
"TransposeOut2"
));
op
->
SetInput
(
"TransposeOut2"
,
this
->
Output
(
"TransposeOut2"
));
op
->
SetInput
(
"QKOut"
,
this
->
Output
(
"QKOut"
));
op
->
SetInput
(
"QKOut"
,
this
->
Output
(
"QKOut"
));
op
->
SetInput
(
"QKTVOut"
,
this
->
Output
(
"QKTVOut"
));
op
->
SetInput
(
"QKTVOut"
,
this
->
Output
(
"QKTVOut"
));
...
@@ -553,8 +571,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -553,8 +571,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
}
}
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVOut"
),
this
->
OutputGrad
(
"QKVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVOut"
),
this
->
OutputGrad
(
"QKVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKVBiasOut"
),
this
->
OutputGrad
(
"QKVBiasOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"QKTVOut"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"QKTVOut"
),
this
->
OutputGrad
(
"QKTVOut"
));
this
->
OutputGrad
(
"QKTVOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"TransposeOut2"
),
op
->
SetOutput
(
framework
::
GradVarName
(
"TransposeOut2"
),
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
1a8786cf
...
@@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
*
x_data
=
input_x
->
data
<
T
>
();
auto
*
x_data
=
input_x
->
data
<
T
>
();
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_bias_data
=
qkv_bias
->
data
<
T
>
();
auto
*
qkv_bias_data
=
(
qkv_bias
==
nullptr
)
?
nullptr
:
qkv_bias
->
data
<
T
>
();
auto
*
qkv_out_data
=
qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qkv_out_data
=
qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qkv_bias_out_data
=
qkv_bias_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
qkv_bias_out_data
=
(
qkv_bias
==
nullptr
)
?
nullptr
:
qkv_bias_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// get data ptr for FMHA.
// get data ptr for FMHA.
auto
*
transpose_out_2_data
=
auto
*
transpose_out_2_data
=
...
@@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// get data ptr for out_linear.
// get data ptr for out_linear.
auto
*
out_linear_weight_data
=
out_linear_weight
->
data
<
T
>
();
auto
*
out_linear_weight_data
=
out_linear_weight
->
data
<
T
>
();
auto
*
out_linear_bias_data
=
out_linear_bias
->
data
<
T
>
();
auto
*
out_linear_bias_data
=
(
out_linear_bias
==
nullptr
)
?
nullptr
:
out_linear_bias
->
data
<
T
>
();
auto
*
out_linear_out_data
=
out_linear_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
out_linear_out_data
=
out_linear_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// get data ptr for bias+dropout+residual+layernorm
// get data ptr for bias+dropout+residual+layernorm
...
@@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto
layer_norm_compute
=
AttnLayerNorm
<
T
>
(
ctx
.
cuda_device_context
(),
auto
layer_norm_compute
=
AttnLayerNorm
<
T
>
(
ctx
.
cuda_device_context
(),
epsilon
,
bsz_seq
,
dim_embed
);
epsilon
,
bsz_seq
,
dim_embed
);
bool
compute_bias
=
true
;
if
(
qkv_bias
==
nullptr
)
{
compute_bias
=
false
;
}
// (transA, transB, compute_bias) = (false, true, true)
// (transA, transB, compute_bias) = (false, true, true)
auto
qkv_compute
=
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
false
,
true
,
auto
qkv_compute
=
bsz_seq
,
output_size
,
input_size
,
true
);
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
false
,
true
,
bsz_seq
,
output_size
,
input_size
,
compute_bias
);
AttnDropoutParam
attn_dropout_param
(
AttnDropoutParam
attn_dropout_param
(
is_test_1
,
dropout_implementation_1
,
attn_dropout_rate
,
is_test_1
,
dropout_implementation_1
,
attn_dropout_rate
,
...
@@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
...
@@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_compute
.
ComputeForward
(
qkv_weight
,
input_x
,
qkv_bias
,
qkv_out
,
qkv_compute
.
ComputeForward
(
qkv_weight
,
input_x
,
qkv_bias
,
qkv_out
,
qkv_bias_out
);
qkv_bias_out
);
}
}
if
(
qkv_bias
==
nullptr
)
{
fmha_ref_compute
.
ComputeForward
(
*
qkv_out
,
src_mask
,
transpose_out_2
,
qk_out
,
src_mask_out
,
softmax_out
,
attn_dropout_mask_out
,
attn_dropout_out
,
qktv_out
,
fmha_out
);
}
else
{
fmha_ref_compute
.
ComputeForward
(
*
qkv_bias_out
,
src_mask
,
transpose_out_2
,
fmha_ref_compute
.
ComputeForward
(
*
qkv_bias_out
,
src_mask
,
transpose_out_2
,
qk_out
,
src_mask_out
,
softmax_out
,
qk_out
,
src_mask_out
,
softmax_out
,
attn_dropout_mask_out
,
attn_dropout_out
,
attn_dropout_mask_out
,
attn_dropout_out
,
qktv_out
,
fmha_out
);
qktv_out
,
fmha_out
);
}
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// weight: [embed_dim, embed_dim]
...
@@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
out_linear_bias
=
ctx
.
Input
<
Tensor
>
(
"OutLinearBias"
);
auto
*
out_linear_bias
=
ctx
.
Input
<
Tensor
>
(
"OutLinearBias"
);
auto
*
src_mask_data
=
(
src_mask
==
nullptr
?
nullptr
:
src_mask
->
data
<
T
>
());
auto
*
src_mask_data
=
(
src_mask
==
nullptr
?
nullptr
:
src_mask
->
data
<
T
>
());
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_weight_data
=
qkv_weight
->
data
<
T
>
();
auto
*
qkv_bias_data
=
qkv_bias
->
data
<
T
>
();
auto
*
qkv_bias_data
=
(
qkv_bias
==
nullptr
)
?
nullptr
:
qkv_bias
->
data
<
T
>
();
auto
*
out_linear_weight_data
=
out_linear_weight
->
data
<
T
>
();
auto
*
out_linear_weight_data
=
out_linear_weight
->
data
<
T
>
();
auto
*
out_linear_bias_data
=
out_linear_bias
->
data
<
T
>
();
auto
*
out_linear_bias_data
=
(
out_linear_bias
==
nullptr
)
?
nullptr
:
out_linear_bias
->
data
<
T
>
();
// fw output
// fw output
auto
*
fmha_out
=
ctx
.
Input
<
Tensor
>
(
"FMHAOut"
);
auto
*
fmha_out
=
ctx
.
Input
<
Tensor
>
(
"FMHAOut"
);
...
@@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
d_bias_dropout_residual_out
=
auto
*
d_bias_dropout_residual_out
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
));
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"BiasDropoutResidualOut"
));
auto
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_out_data
=
d_qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
auto
*
d_qkv_bias_out_data
=
d_qkv_bias_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// space can be reused.
auto
*
d_qkv_out_data
=
(
d_qkv_bias_out
!=
nullptr
)
?
nullptr
:
d_qkv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_out_data
=
(
d_qkv_bias_out
==
nullptr
)
?
nullptr
:
d_qkv_bias_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qktv_out_data
=
d_qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qktv_out_data
=
d_qktv_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_transpose_out_2_data
=
auto
*
d_transpose_out_2_data
=
d_transpose_out_2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_transpose_out_2
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
d_ln_2_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
auto
*
d_ln_2_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Ln2Bias"
));
auto
*
d_qkv_weight_data
=
d_qkv_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_weight_data
=
d_qkv_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_data
=
d_qkv_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_qkv_bias_data
=
(
d_qkv_bias
==
nullptr
)
?
nullptr
:
d_qkv_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_out_linear_weight_data
=
auto
*
d_out_linear_weight_data
=
d_out_linear_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_out_linear_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_out_linear_bias_data
=
auto
*
d_out_linear_bias_data
=
d_out_linear_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
(
d_out_linear_bias
==
nullptr
)
?
nullptr
:
d_out_linear_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
auto
input_x_dims
=
input_x
->
dims
();
const
auto
input_x_dims
=
input_x
->
dims
();
const
auto
qkv_w_dims
=
qkv_weight
->
dims
();
const
auto
qkv_w_dims
=
qkv_weight
->
dims
();
...
@@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
bool
transA
=
false
;
bool
transA
=
false
;
bool
transB
=
true
;
bool
transB
=
true
;
bool
compute_bias
=
true
;
bool
compute_qkv_bias
=
true
;
if
(
qkv_bias
==
nullptr
)
{
compute_qkv_bias
=
false
;
}
auto
layer_norm_compute
=
AttnLayerNorm
<
T
>
(
ctx
.
cuda_device_context
(),
auto
layer_norm_compute
=
AttnLayerNorm
<
T
>
(
ctx
.
cuda_device_context
(),
epsilon
,
bsz_seq
,
dim_embed
);
epsilon
,
bsz_seq
,
dim_embed
);
auto
qkv_compute
=
auto
qkv_compute
=
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
transA
,
transB
,
bsz_seq
,
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
transA
,
transB
,
bsz_seq
,
output_size
,
input_size
,
compute_bias
);
output_size
,
input_size
,
compute_
qkv_
bias
);
AttnDropoutParam
attn_dropout_param
(
AttnDropoutParam
attn_dropout_param
(
is_test_1
,
dropout_implementation_1
,
attn_dropout_prob
,
is_test_1
,
dropout_implementation_1
,
attn_dropout_prob
,
is_upscale_in_train_1
,
is_fix_seed_1
,
seed_val_1
,
seed_1
);
is_upscale_in_train_1
,
is_fix_seed_1
,
seed_val_1
,
seed_1
);
...
@@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
output_size
=
hidden_size
;
output_size
=
hidden_size
;
transA
=
false
;
transA
=
false
;
transB
=
false
;
transB
=
false
;
compute_bias
=
false
;
bool
compute_bias
=
false
;
auto
out_linear_compute
=
auto
out_linear_compute
=
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
transA
,
transB
,
bsz_seq
,
AttnMatMul
<
T
>
(
ctx
.
cuda_device_context
(),
transA
,
transB
,
bsz_seq
,
output_size
,
input_size
,
compute_bias
);
output_size
,
input_size
,
compute_bias
);
...
@@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out
,
d_fmha_out
,
d_out_linear_out
,
d_fmha_out
,
d_out_linear_weight
,
nullptr
);
d_out_linear_weight
,
nullptr
);
if
(
qkv_bias
!=
nullptr
)
{
fmha_ref_compute
.
ComputeBackward
(
fmha_ref_compute
.
ComputeBackward
(
*
transpose_out_2
,
src_mask
,
*
softmax_out
,
*
attn_dropout_mask_out
,
*
transpose_out_2
,
src_mask
,
*
softmax_out
,
*
attn_dropout_mask_out
,
*
attn_dropout_out
,
*
qk_out
,
*
src_mask_out
,
*
d_fmha_out
,
d_qktv_out
,
*
attn_dropout_out
,
*
qk_out
,
*
src_mask_out
,
*
d_fmha_out
,
d_qktv_out
,
d_attn_dropout_out
,
d_softmax_out
,
d_src_mask_out
,
d_qk_out
,
d_attn_dropout_out
,
d_softmax_out
,
d_src_mask_out
,
d_qk_out
,
d_transpose_out_2
,
nullptr
,
d_qkv_bias_out
);
d_transpose_out_2
,
nullptr
,
d_qkv_bias_out
);
cudaMemcpyAsync
(
d_qkv_out_data
,
d_qkv_bias_out_data
,
}
else
{
bsz_seq
*
3
*
num_head
*
dim_head
*
sizeof
(
T
),
fmha_ref_compute
.
ComputeBackward
(
cudaMemcpyDeviceToDevice
);
*
transpose_out_2
,
src_mask
,
*
softmax_out
,
*
attn_dropout_mask_out
,
*
attn_dropout_out
,
*
qk_out
,
*
src_mask_out
,
*
d_fmha_out
,
d_qktv_out
,
d_attn_dropout_out
,
d_softmax_out
,
d_src_mask_out
,
d_qk_out
,
d_transpose_out_2
,
nullptr
,
d_qkv_out
);
}
if
(
pre_layer_norm
)
{
if
(
pre_layer_norm
)
{
auto
*
ln_mean
=
ctx
.
Input
<
Tensor
>
(
"LnMean"
);
auto
*
ln_mean
=
ctx
.
Input
<
Tensor
>
(
"LnMean"
);
...
@@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
...
@@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto
*
d_ln_bias_data
=
auto
*
d_ln_bias_data
=
(
d_ln_bias
==
nullptr
?
nullptr
(
d_ln_bias
==
nullptr
?
nullptr
:
d_ln_bias
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
:
d_ln_bias
->
mutable_data
<
U
>
(
ctx
.
GetPlace
()));
if
(
qkv_bias
!=
nullptr
)
{
qkv_compute
.
ComputeBackward
(
ln_out
,
qkv_weight
,
d_qkv_bias_out
,
d_ln_out
,
qkv_compute
.
ComputeBackward
(
ln_out
,
qkv_weight
,
d_qkv_bias_out
,
d_ln_out
,
d_qkv_weight
,
d_qkv_bias
);
}
else
{
qkv_compute
.
ComputeBackward
(
ln_out
,
qkv_weight
,
d_qkv_out
,
d_ln_out
,
d_qkv_weight
,
d_qkv_bias
);
d_qkv_weight
,
d_qkv_bias
);
}
layer_norm_compute
.
ComputeBackward
(
x_data
,
d_ln_out_data
,
ln_scale_data
,
layer_norm_compute
.
ComputeBackward
(
x_data
,
d_ln_out_data
,
ln_scale_data
,
ln_mean_data
,
ln_var_data
,
d_x_data
,
ln_mean_data
,
ln_var_data
,
d_x_data
,
d_ln_scale_data
,
d_ln_bias_data
);
d_ln_scale_data
,
d_ln_bias_data
);
}
else
{
}
else
{
if
(
qkv_bias
!=
nullptr
)
{
qkv_compute
.
ComputeBackward
(
input_x
,
qkv_weight
,
d_qkv_bias_out
,
d_x
,
qkv_compute
.
ComputeBackward
(
input_x
,
qkv_weight
,
d_qkv_bias_out
,
d_x
,
d_qkv_weight
,
d_qkv_bias
);
d_qkv_weight
,
d_qkv_bias
);
}
else
{
qkv_compute
.
ComputeBackward
(
input_x
,
qkv_weight
,
d_qkv_out
,
d_x
,
d_qkv_weight
,
d_qkv_bias
);
}
}
}
// gradient accumulation
// gradient accumulation
std
::
vector
<
const
Tensor
*>
ins
;
std
::
vector
<
const
Tensor
*>
ins
;
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_op.py
浏览文件 @
1a8786cf
...
@@ -168,15 +168,27 @@ class TestFusedAttentionOp(OpTest):
...
@@ -168,15 +168,27 @@ class TestFusedAttentionOp(OpTest):
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
paddle
.
disable_static
(
place
=
paddle
.
CUDAPlace
(
0
))
q_proj_weight
=
paddle
.
to_tensor
(
q_proj_weight
=
paddle
.
to_tensor
(
self
.
q_proj
.
weight
,
stop_gradient
=
False
)
self
.
q_proj
.
weight
,
stop_gradient
=
False
)
q_proj_bias
=
paddle
.
to_tensor
(
self
.
q_proj
.
bias
,
stop_gradient
=
False
)
k_proj_weight
=
paddle
.
to_tensor
(
k_proj_weight
=
paddle
.
to_tensor
(
self
.
k_proj
.
weight
,
stop_gradient
=
False
)
self
.
k_proj
.
weight
,
stop_gradient
=
False
)
k_proj_bias
=
paddle
.
to_tensor
(
self
.
k_proj
.
bias
,
stop_gradient
=
False
)
v_proj_weight
=
paddle
.
to_tensor
(
v_proj_weight
=
paddle
.
to_tensor
(
self
.
v_proj
.
weight
,
stop_gradient
=
False
)
self
.
v_proj
.
weight
,
stop_gradient
=
False
)
v_proj_bias
=
paddle
.
to_tensor
(
self
.
v_proj
.
bias
,
stop_gradient
=
False
)
out_linear_weight
=
paddle
.
to_tensor
(
out_linear_weight
=
paddle
.
to_tensor
(
self
.
out_proj
.
weight
,
stop_gradient
=
False
)
self
.
out_proj
.
weight
,
stop_gradient
=
False
)
if
self
.
bias_attr
is
False
:
qkv_bias_tensor
=
None
out_linear_bias
=
None
else
:
q_proj_bias
=
paddle
.
to_tensor
(
self
.
q_proj
.
bias
,
stop_gradient
=
False
)
k_proj_bias
=
paddle
.
to_tensor
(
self
.
k_proj
.
bias
,
stop_gradient
=
False
)
v_proj_bias
=
paddle
.
to_tensor
(
self
.
v_proj
.
bias
,
stop_gradient
=
False
)
qkv_bias
=
np
.
concatenate
(
(
q_proj_bias
.
numpy
(),
k_proj_bias
.
numpy
(),
v_proj_bias
.
numpy
()))
qkv_bias
=
qkv_bias
.
reshape
((
3
,
self
.
num_heads
,
self
.
head_dim
))
qkv_bias_tensor
=
paddle
.
to_tensor
(
qkv_bias
,
stop_gradient
=
False
)
out_linear_bias
=
paddle
.
to_tensor
(
out_linear_bias
=
paddle
.
to_tensor
(
self
.
out_proj
.
bias
,
stop_gradient
=
False
)
self
.
out_proj
.
bias
,
stop_gradient
=
False
)
...
@@ -193,17 +205,12 @@ class TestFusedAttentionOp(OpTest):
...
@@ -193,17 +205,12 @@ class TestFusedAttentionOp(OpTest):
qkv_weight
=
qkv_weight
.
reshape
(
qkv_weight
=
qkv_weight
.
reshape
(
(
3
,
self
.
num_heads
,
self
.
head_dim
,
self
.
embed_dim
))
(
3
,
self
.
num_heads
,
self
.
head_dim
,
self
.
embed_dim
))
qkv_bias
=
np
.
concatenate
(
(
q_proj_bias
.
numpy
(),
k_proj_bias
.
numpy
(),
v_proj_bias
.
numpy
()))
qkv_bias
=
qkv_bias
.
reshape
((
3
,
self
.
num_heads
,
self
.
head_dim
))
x
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
x
=
paddle
.
to_tensor
(
self
.
query
,
stop_gradient
=
False
)
if
self
.
has_attn_mask
:
if
self
.
has_attn_mask
:
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
attn_mask
=
paddle
.
to_tensor
(
self
.
attn_mask
,
stop_gradient
=
False
)
else
:
else
:
attn_mask
=
None
attn_mask
=
None
qkv_weight_tensor
=
paddle
.
to_tensor
(
qkv_weight
,
stop_gradient
=
False
)
qkv_weight_tensor
=
paddle
.
to_tensor
(
qkv_weight
,
stop_gradient
=
False
)
qkv_bias_tensor
=
paddle
.
to_tensor
(
qkv_bias
,
stop_gradient
=
False
)
epsilon
=
1e-05
epsilon
=
1e-05
ln2_epsilon
=
1e-05
ln2_epsilon
=
1e-05
...
@@ -227,6 +234,36 @@ class TestFusedAttentionOp(OpTest):
...
@@ -227,6 +234,36 @@ class TestFusedAttentionOp(OpTest):
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
class
TestFusedAttentionOpBiasIsNone
(
TestFusedAttentionOp
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
self
.
pre_layer_norm
=
False
self
.
has_attn_mask
=
True
self
.
training
=
True
self
.
batch_size
=
8
self
.
query_length
=
128
self
.
head_dim
=
64
self
.
num_heads
=
16
self
.
embed_dim
=
self
.
head_dim
*
self
.
num_heads
self
.
dropout_prob
=
0.0
self
.
attn_dropout_prob
=
0.0
self
.
weight_attr
=
None
self
.
bias_attr
=
False
self
.
kdim
,
self
.
vdim
=
self
.
embed_dim
,
self
.
embed_dim
self
.
key_length
,
self
.
value_length
=
self
.
query_length
,
self
.
query_length
def
test_fused_attention_op
(
self
):
final_out_ref
,
x_grad_ref
=
self
.
GetBaselineOut
()
final_out
,
x_grad
=
self
.
GetFusedAttentionOut
()
np
.
testing
.
assert_allclose
(
final_out_ref
,
final_out
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
np
.
testing
.
assert_allclose
(
x_grad_ref
,
x_grad
.
numpy
(),
rtol
=
1e-5
,
atol
=
1e-4
)
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
class
TestFusedAttentionOpPreLn
(
TestFusedAttentionOp
):
def
config
(
self
):
def
config
(
self
):
self
.
x_type
=
np
.
float32
self
.
x_type
=
np
.
float32
...
...
python/paddle/incubate/nn/functional/fused_transformer.py
浏览文件 @
1a8786cf
...
@@ -356,6 +356,9 @@ def fused_multi_head_attention(x,
...
@@ -356,6 +356,9 @@ def fused_multi_head_attention(x,
0
]
==
3
,
"The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
0
]
==
3
,
"The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
assert
qkv_weight
.
shape
[
3
]
==
x
.
shape
[
assert
qkv_weight
.
shape
[
3
]
==
x
.
shape
[
2
],
"The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
2
],
"The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
assert
qkv_weight
.
shape
[
1
]
*
qkv_weight
.
shape
[
2
]
==
qkv_weight
.
shape
[
3
],
"embed_dim must be divisible by num_heads."
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
final_out
=
_C_ops
.
fused_attention
(
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
final_out
=
_C_ops
.
fused_attention
(
x
,
pre_ln_scale
,
pre_ln_bias
,
qkv_weight
,
qkv_bias
,
attn_mask
,
x
,
pre_ln_scale
,
pre_ln_bias
,
qkv_weight
,
qkv_bias
,
attn_mask
,
linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
'pre_layer_norm'
,
linear_weight
,
linear_bias
,
ln_scale
,
ln_bias
,
'pre_layer_norm'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录