BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit a365024c, authored on Dec 01, 2022 by minghaoBD, committed by GitHub on Dec 01, 2022
fuse-mt passes compatible with structured pruning (#48585)
* fuse-mt passes compatible with structured pruning
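In brief, the passes previously read num_head from index 2 of reshape2's "shape" attribute, and fused_multi_transformer_op additionally enforced num_head * dim_head == dim_embed; neither holds once the num_head dimension has been structurally pruned. The passes now keep dim_head (reshape.shape[3]) and dim_embed (layer_norm_bias.shape[0]) and recompute num_head from the QKV weight's actual shape. A minimal standalone sketch of that arithmetic, using hypothetical shape values that are not taken from this commit:

// Sketch only (not part of the commit): standalone arithmetic mirroring the
// updated pass logic. All shape values below are hypothetical.
#include <cassert>
#include <vector>

int main() {
  const int dim_embed = 768;  // from layer_norm_bias.shape[0]
  const int dim_head = 64;    // from reshape2 "shape" attribute, index 3

  // Hypothetical Q projection weight [dim_embed, num_head * dim_head] after
  // 4 of the original 12 heads were pruned away.
  const std::vector<int> wq_dims = {768, 512};

  // New derivation: num_head comes from the (possibly pruned) weight itself.
  const int num_head = wq_dims[1] / dim_head;  // 8

  // After pruning, num_head * dim_head no longer equals dim_embed, which is
  // what the removed InferShape check in fused_multi_transformer_op.cc
  // used to require.
  assert(num_head == 8);
  assert(num_head * dim_head != dim_embed);
  return 0;
}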
Parent: 310f4320
Showing 2 changed files with 40 additions and 53 deletions (+40 −53)
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc  +40 −32
paddle/fluid/operators/fused/fused_multi_transformer_op.cc  +0 −21
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
@@ -1325,17 +1325,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
                           Node* ffn_output) {
-    auto reshape_desc = reshape2_0->Op();
-    int num_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(2);
-    int dim_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(3);
-    auto* layer_norm_bias_tensor =
-        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
-    int dim_embed = layer_norm_bias_tensor->dims()[0];
-
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_0_op = ffn_matmul0->Op();
@@ -1364,6 +1353,20 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
     auto* bv_tensor =
         scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
 
+    // NOTE(minghaoBD): to make it compatible with strucutured pruning on
+    // num_head dimension:
+    // 1. get dim_head from reshape.shape[3], dim_embed from
+    // layer_norm_bias.shape[0]
+    // 2. calculate num_head according to wq_tensor.shape[1] and dim_head
+    auto reshape_desc = reshape2_0->Op();
+    int dim_head =
+        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
+            .at(3);
+    auto* layer_norm_bias_tensor =
+        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
+    int dim_embed = layer_norm_bias_tensor->dims()[0];
+    int num_head = wq_tensor->dims()[1] / dim_head;
+
     QKVWeightsBiasProcess(wq_tensor,
                           wk_tensor,
                           wv_tensor,
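As a concrete illustration of the note above, with hypothetical shapes not taken from this commit: if reshape2.shape[3] gives dim_head = 64 and structured pruning has shrunk wq_tensor from [768, 768] to [768, 512], the pass now infers num_head = 512 / 64 = 8 while dim_embed stays 768 from layer_norm_bias.shape[0], instead of reading num_head from reshape2.shape[2] as the deleted lines did.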
@@ -2195,18 +2198,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
                           Node* ffn_output) {
-    auto reshape_desc = reshape2_0->Op();
-    int num_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(2);
-    int dim_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(3) /
-        3;  // 3 for qkv
-    auto* layer_norm_bias_tensor =
-        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
-    int dim_embed = layer_norm_bias_tensor->dims()[0];
-
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_0_op = ffn_matmul0->Op();
@@ -2226,6 +2217,21 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
     auto* qkv_b_tensor =
         scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
 
+    // NOTE(minghaoBD): to make it compatible with strucutured pruning on
+    // num_head dimension:
+    // 1. get dim_head from reshape.shape[3], dim_embed from
+    // layer_norm_bias.shape[0]
+    // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
+    auto reshape_desc = reshape2_0->Op();
+    int dim_head =
+        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
+            .at(3) /
+        3;  // 3 for qkv
+    auto* layer_norm_bias_tensor =
+        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
+    int dim_embed = layer_norm_bias_tensor->dims()[0];
+    int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head;
+
     QKVWeightsBiasProcessFuseQKV(
         qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
 
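In the fuse-QKV passes Q, K and V share one packed weight, so the same derivation carries an extra factor of 3, as in the hunk above. A minimal sketch with hypothetical shapes (not taken from this commit):

// Sketch only (not part of the commit): the fused-QKV form of the num_head
// derivation. All shape values below are hypothetical.
#include <iostream>
#include <vector>

int main() {
  // Hypothetical reshape2 "shape" attribute in the fuse-QKV pattern; its
  // last entry packs 3 * dim_head, hence the "/ 3" in the pass.
  const std::vector<int> reshape_shape = {0, 0, 8, 192};
  const int dim_head = reshape_shape.at(3) / 3;  // 64, "3 for qkv"

  // Hypothetical fused QKV weight [dim_embed, 3 * num_head * dim_head]
  // after pruning to 8 heads of size 64.
  const std::vector<int> qkv_w_dims = {768, 3 * 8 * 64};
  const int num_head = qkv_w_dims[1] / 3 / dim_head;  // 8

  std::cout << "dim_head=" << dim_head << " num_head=" << num_head << "\n";
  return 0;
}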
@@ -2995,15 +3001,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
                           Node* ffn_output) {
-    auto reshape_desc = reshape2_0->Op();
-    int num_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(2);
-    int dim_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(3) /
-        3;  // 3 for qkv
-
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_0_op = ffn_matmul0->Op();
@@ -3023,9 +3020,20 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
     auto* qkv_b_tensor =
         scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
 
+    // NOTE(minghaoBD): to make it compatible with strucutured pruning on
+    // num_head dimension:
+    // 1. get dim_head from reshape.shape[3], dim_embed from
+    // layer_norm_bias.shape[0]
+    // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
     auto* layer_norm_bias_tensor =
         scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
     int dim_embed = layer_norm_bias_tensor->dims()[0];
+    auto reshape_desc = reshape2_0->Op();
+    int dim_head =
+        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
+            .at(3) /
+        3;  // 3 for qkv
+    int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head;
 
     QKVWeightsBiasProcessFuseQKV(
         qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
paddle/fluid/operators/fused/fused_multi_transformer_op.cc
@@ -93,27 +93,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
                           x_dim,
                           y_dim));
 
-    if (ctx->Attrs().Get<int>("ring_id") == -1) {
-      if (trans_qkvw) {
-        PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
-                          y_dim[3],
-                          platform::errors::InvalidArgument(
-                              "The dimensions of qkv_weight must be 4"
-                              "(3, num_head, dim_head, dim_embed),"
-                              "and must satisfy the limitations: "
-                              "(num_head * dim_head == dim_embed)"));
-
-      } else {
-        PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3],
-                          y_dim[0],
-                          platform::errors::InvalidArgument(
-                              "The dimensions of qkv_weight must be 4"
-                              "(dim_embed, 3, num_head, dim_head),"
-                              "and must satisfy the limitations: "
-                              "(num_head * dim_head == dim_embed)"));
-      }
-    }
-
     if (ctx->HasInputs("CacheKV")) {
       // [2, batch_size, num_head, max_seq_len, head_size]
       const auto& c_dims = ctx->GetInputsDim("CacheKV");
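The removed branch only applied on a single card (ring_id == -1) and, per its error messages, required num_head * dim_head == dim_embed for the rank-4 qkv_weight. With hypothetical numbers: a model pruned from 12 to 8 heads of size 64 gives num_head * dim_head = 512 against dim_embed = 768, so the old check would reject exactly the weights the updated passes now accept; the surrounding shape checks kept as context above are unchanged.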