Unverified commit 5c9299e5
Authored by zhupengyang on Feb 23, 2023; committed via GitHub on Feb 23, 2023
[XPU] optimize multi_encoder_xpu_pass (#50759)
Parent: 91992dac

Showing 2 changed files with 113 additions and 117 deletions.
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc  +112 −117
paddle/fluid/inference/api/paddle_pass_builder.cc  +1 −0
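In outline, the change has three parts: (1) the single-encoder pattern gains a with_q_scale variant, so encoders that keep an explicit scale op on Q between the transpose and the QK matmul can also be fused; (2) the pattern stops matching the XShape outputs of reshape2/transpose2 and defers all LinksFrom/LinksTo wiring to a single link phase; (3) the candidate matmul op types are reordered to try matmul_v2 first, and XpuPassStrategy additionally registers identity_scale_op_clean_pass.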
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc
@@ -55,6 +55,7 @@ struct SingleEncoderXPUPattern : public PatternBase {
                           const std::string& matmul_type_1,
                           const std::string& matmul_type_2,
                           bool norm_before,
+                          bool with_q_scale,
                           bool with_mask);
   // declare operator node's name
@@ -67,6 +68,7 @@ struct SingleEncoderXPUPattern : public PatternBase {
   PATTERN_DECL_NODE(q_add);
   PATTERN_DECL_NODE(q_reshape);
   PATTERN_DECL_NODE(q_transpose);
+  PATTERN_DECL_NODE(q_scale);
   PATTERN_DECL_NODE(k_matmul);
   PATTERN_DECL_NODE(k_add);
   PATTERN_DECL_NODE(k_reshape);
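PATTERN_DECL_NODE declares the accessor for one named node of the pattern; the new q_scale entry reserves a slot for the optional scale op this change starts matching, alongside the existing Q/K/V nodes.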
@@ -102,34 +104,27 @@ struct SingleEncoderXPUPattern : public PatternBase {
   PATTERN_DECL_NODE(q_add_bias);
   PATTERN_DECL_NODE(q_add_out);
   PATTERN_DECL_NODE(q_reshape_out);
-  PATTERN_DECL_NODE(q_reshape_xshape);
   PATTERN_DECL_NODE(q_transpose_out);
-  PATTERN_DECL_NODE(q_transpose_xshape);
+  PATTERN_DECL_NODE(q_scale_out);
   PATTERN_DECL_NODE(k_matmul_w);
   PATTERN_DECL_NODE(k_matmul_out);
   PATTERN_DECL_NODE(k_add_bias);
   PATTERN_DECL_NODE(k_add_out);
   PATTERN_DECL_NODE(k_reshape_out);
-  PATTERN_DECL_NODE(k_reshape_xshape);
   PATTERN_DECL_NODE(k_transpose_out);
-  PATTERN_DECL_NODE(k_transpose_xshape);
   PATTERN_DECL_NODE(v_matmul_w);
   PATTERN_DECL_NODE(v_matmul_out);
   PATTERN_DECL_NODE(v_add_bias);
   PATTERN_DECL_NODE(v_add_out);
   PATTERN_DECL_NODE(v_reshape_out);
-  PATTERN_DECL_NODE(v_reshape_xshape);
   PATTERN_DECL_NODE(v_transpose_out);
-  PATTERN_DECL_NODE(v_transpose_xshape);
   PATTERN_DECL_NODE(qk_matmul_out);
   PATTERN_DECL_NODE(qk_add_mask);
   PATTERN_DECL_NODE(qk_add_out);
   PATTERN_DECL_NODE(qk_softmax_out);
   PATTERN_DECL_NODE(qkv_matmul_0_out);
   PATTERN_DECL_NODE(qkv_transpose_out);
-  PATTERN_DECL_NODE(qkv_transpose_xshape);
   PATTERN_DECL_NODE(qkv_reshape_out);
-  PATTERN_DECL_NODE(qkv_reshape_xshape);
   PATTERN_DECL_NODE(qkv_matmul_1_w);
   PATTERN_DECL_NODE(qkv_matmul_1_out);
   PATTERN_DECL_NODE(qkv_add_0_bias);
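The pattern sheds all of its *_xshape slots here: reshape2 and transpose2 each emit an auxiliary XShape output that records the pre-op shape (used mainly for the backward pass), and the old pattern had to declare, match, and later delete those variables. The new pattern simply ignores them, likely making the pass robust to inference programs in which XShape variables have already been pruned; a q_scale_out slot is added for the optional scale op's output.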
@@ -162,7 +157,8 @@ struct SingleEncoderXPUPattern : public PatternBase {
   std::string matmul_type_0_;
   std::string matmul_type_1_;
   std::string matmul_type_2_;
-  bool norm_before_{true};
+  bool norm_before_{false};
+  bool with_q_scale_{false};
   bool with_mask_{true};
 };
@@ -174,6 +170,7 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
     const std::string& matmul_type_1,
     const std::string& matmul_type_2,
     bool norm_before,
+    bool with_q_scale,
     bool with_mask)
     : PatternBase(pattern, name_scope, name_scope),
       act_type_(act_type),
@@ -181,30 +178,34 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
       matmul_type_1_(matmul_type_1),
       matmul_type_2_(matmul_type_2),
       norm_before_(norm_before),
+      with_q_scale_(with_q_scale),
       with_mask_(with_mask) {
   // layer_norm 0
   PDNode* ln_0_x = pattern->NewNode(ln_0_x_repr());
+  PDNode* ln_0_bias = nullptr;
+  PDNode* ln_0_scale = nullptr;
+  PDNode* ln_0 = nullptr;
   PDNode* ln_0_out = nullptr;
+  PDNode* ln_0_mean = nullptr;
+  PDNode* ln_0_variance = nullptr;
   if (norm_before_) {
     ln_0_x->assert_is_op_input("layer_norm", "X")->assert_var_not_persistable();
-    auto* ln_0_bias = pattern->NewNode(ln_0_bias_repr())
-                          ->assert_is_op_input("layer_norm", "Bias")
-                          ->assert_is_persistable_var();
-    auto* ln_0_scale = pattern->NewNode(ln_0_scale_repr())
-                           ->assert_is_op_input("layer_norm", "Scale")
-                           ->assert_is_persistable_var();
-    auto* ln_0 = pattern->NewNode(ln_0_repr())->assert_is_op("layer_norm");
+    ln_0_bias = pattern->NewNode(ln_0_bias_repr())
+                    ->assert_is_op_input("layer_norm", "Bias")
+                    ->assert_is_persistable_var();
+    ln_0_scale = pattern->NewNode(ln_0_scale_repr())
+                     ->assert_is_op_input("layer_norm", "Scale")
+                     ->assert_is_persistable_var();
+    ln_0 = pattern->NewNode(ln_0_repr())->assert_is_op("layer_norm");
     ln_0_out = pattern->NewNode(ln_0_out_repr())
                    ->assert_is_op_output("layer_norm", "Y")
                    ->assert_var_not_persistable();
-    auto* ln_0_mean = pattern->NewNode(ln_0_mean_repr())
-                          ->assert_is_op_output("layer_norm", "Mean")
-                          ->assert_var_not_persistable();
-    auto* ln_0_variance = pattern->NewNode(ln_0_variance_repr())
-                              ->assert_is_op_output("layer_norm", "Variance")
-                              ->assert_var_not_persistable();
-    ln_0->LinksFrom({ln_0_x, ln_0_bias, ln_0_scale})
-        .LinksTo({ln_0_out, ln_0_mean, ln_0_variance});
+    ln_0_mean = pattern->NewNode(ln_0_mean_repr())
+                    ->assert_is_op_output("layer_norm", "Mean")
+                    ->assert_var_not_persistable();
+    ln_0_variance = pattern->NewNode(ln_0_variance_repr())
+                        ->assert_is_op_output("layer_norm", "Variance")
+                        ->assert_var_not_persistable();
   }
   // q: matmul + add + reshape + transpose
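This hunk (and the ln_2 and qk_add hunks below) follows one idea: optional pattern nodes are now declared as nullptrs up front, created only in the branch that needs them, and wired together later in a single "link nodes" step. A minimal standalone sketch of the idea, with toy types standing in for Paddle's PDNode machinery (names are illustrative, not the real API):

#include <cstdio>
#include <deque>
#include <vector>

// Toy stand-ins: declare optional nodes as nullptrs, create them only in the
// branch that needs them, and wire edges in one later "link" step.
struct Node {
  const char* name;
  std::vector<Node*> inputs;
};

std::deque<Node> pool;  // deque keeps element pointers stable on push_back

Node* NewNode(const char* name) {
  pool.push_back(Node{name, {}});
  return &pool.back();
}

int main(int argc, char**) {
  const bool with_q_scale = argc > 1;  // variant flag, like with_q_scale_

  Node* q_transpose_out = NewNode("q_transpose_out");
  Node* q_scale = nullptr;      // exists only in the scaled variant
  Node* q_scale_out = nullptr;
  if (with_q_scale) {
    q_scale = NewNode("q_scale");
    q_scale_out = NewNode("q_scale_out");
  }

  // Link phase: choose the QK-matmul input according to the variant.
  Node* qk_matmul_x = q_transpose_out;
  if (with_q_scale) {
    q_scale->inputs.push_back(q_transpose_out);
    q_scale_out->inputs.push_back(q_scale);
    qk_matmul_x = q_scale_out;
  }
  std::printf("qk_matmul reads %s\n", qk_matmul_x->name);
  return 0;
}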
@@ -228,18 +229,22 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* q_reshape_out = pattern->NewNode(q_reshape_out_repr())
                             ->assert_is_op_output("reshape2", "Out")
                             ->assert_var_not_persistable();
-  auto* q_reshape_xshape = pattern->NewNode(q_reshape_xshape_repr())
-                               ->assert_is_op_output("reshape2", "XShape")
-                               ->assert_var_not_persistable();
   auto* q_transpose =
       pattern->NewNode(q_transpose_repr())->assert_is_op("transpose2");
   auto* q_transpose_out = pattern->NewNode(q_transpose_out_repr())
                               ->assert_is_op_output("transpose2", "Out")
-                              ->assert_is_op_input(matmul_type_1_, "X")
                               ->assert_var_not_persistable();
-  auto* q_transpose_xshape = pattern->NewNode(q_transpose_xshape_repr())
-                                 ->assert_is_op_output("transpose2", "XShape")
-                                 ->assert_var_not_persistable();
+  PDNode* q_scale = nullptr;
+  PDNode* q_scale_out = nullptr;
+  if (with_q_scale_) {
+    q_scale = pattern->NewNode(q_scale_repr())->assert_is_op("scale");
+    q_scale_out = pattern->NewNode(q_scale_out_repr())
+                      ->assert_is_op_output("scale", "Out")
+                      ->assert_is_op_input(matmul_type_1_, "X")
+                      ->assert_var_not_persistable();
+  } else {
+    q_transpose_out->assert_is_op_input(matmul_type_1_, "X");
+  }
   // k: matmul + add + reshape + transpose
   auto k_matmul_w = pattern->NewNode(k_matmul_w_repr())
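The new with_q_scale_ branch covers models that keep the attention scaling as an explicit scale op between the Q transpose and the first QK matmul, rather than folding it into the weights or the matmul itself; in either shape the matmul's "X" input assertion ends up on whichever node actually feeds it. Wherever the multiplier lands, the attention logits are the same, which is why both graph shapes can fuse to the same encoder op. A toy check (not Paddle code; head_dim = 64 and the operand values are illustrative):

#include <cmath>
#include <cstdio>

// The usual transformer scaling is 1/sqrt(head_dim); applying it to Q or to
// the Q*K product gives the same logits, up to float rounding.
int main() {
  const float alpha = 1.0f / std::sqrt(64.0f);
  const float q = 0.8f, k = -1.2f;
  std::printf("scale-then-matmul: %f\n", (q * alpha) * k);
  std::printf("matmul-then-scale: %f\n", (q * k) * alpha);
  return 0;
}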
@@ -262,18 +267,12 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* k_reshape_out = pattern->NewNode(k_reshape_out_repr())
                             ->assert_is_op_output("reshape2", "Out")
                             ->assert_var_not_persistable();
-  auto* k_reshape_xshape = pattern->NewNode(k_reshape_xshape_repr())
-                               ->assert_is_op_output("reshape2", "XShape")
-                               ->assert_var_not_persistable();
   auto* k_transpose =
       pattern->NewNode(k_transpose_repr())->assert_is_op("transpose2");
   auto* k_transpose_out = pattern->NewNode(k_transpose_out_repr())
                               ->assert_is_op_output("transpose2", "Out")
                               ->assert_is_op_input(matmul_type_1_, "Y")
                               ->assert_var_not_persistable();
-  auto* k_transpose_xshape = pattern->NewNode(k_transpose_xshape_repr())
-                                 ->assert_is_op_output("transpose2", "XShape")
-                                 ->assert_var_not_persistable();
   // qk: matmul + add + softmax
   auto* qk_matmul =
@@ -281,17 +280,17 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* qk_matmul_out = pattern->NewNode(qk_matmul_out_repr())
                             ->assert_is_op_output(matmul_type_1_, "Out")
                             ->assert_var_not_persistable();
+  PDNode* qk_add_mask = nullptr;
+  PDNode* qk_add = nullptr;
   PDNode* qk_add_out = nullptr;
   if (with_mask_) {
-    auto qk_add_mask = pattern->NewNode(qk_add_mask_repr())
-                           ->assert_is_op_input("elementwise_add", "Y")
-                           ->assert_var_not_persistable();
-    auto* qk_add =
-        pattern->NewNode(qk_add_repr())->assert_is_op("elementwise_add");
+    qk_add_mask = pattern->NewNode(qk_add_mask_repr())
+                      ->assert_is_op_input("elementwise_add", "Y")
+                      ->assert_var_not_persistable();
+    qk_add = pattern->NewNode(qk_add_repr())->assert_is_op("elementwise_add");
     qk_add_out = pattern->NewNode(qk_add_out_repr())
                      ->assert_is_op_output("elementwise_add", "Out")
                      ->assert_var_not_persistable();
-    qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out});
   }
   auto* qk_softmax =
       pattern->NewNode(qk_softmax_repr())->assert_is_op("softmax");
@@ -321,18 +320,12 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* v_reshape_out = pattern->NewNode(v_reshape_out_repr())
                             ->assert_is_op_output("reshape2", "Out")
                             ->assert_var_not_persistable();
-  auto* v_reshape_xshape = pattern->NewNode(v_reshape_xshape_repr())
-                               ->assert_is_op_output("reshape2", "XShape")
-                               ->assert_var_not_persistable();
   auto* v_transpose =
       pattern->NewNode(v_transpose_repr())->assert_is_op("transpose2");
   auto* v_transpose_out = pattern->NewNode(v_transpose_out_repr())
                               ->assert_is_op_output("transpose2", "Out")
                               ->assert_is_op_input(matmul_type_2_, "Y")
                               ->assert_var_not_persistable();
-  auto* v_transpose_xshape = pattern->NewNode(v_transpose_xshape_repr())
-                                 ->assert_is_op_output("transpose2", "XShape")
-                                 ->assert_var_not_persistable();
   // qkv
   auto* qkv_matmul_0 =
@@ -345,17 +338,11 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* qkv_transpose_out = pattern->NewNode(qkv_transpose_out_repr())
                                 ->assert_is_op_output("transpose2", "Out")
                                 ->assert_var_not_persistable();
-  auto* qkv_transpose_xshape = pattern->NewNode(qkv_transpose_xshape_repr())
-                                   ->assert_is_op_output("transpose2", "XShape")
-                                   ->assert_var_not_persistable();
   auto* qkv_reshape =
       pattern->NewNode(qkv_reshape_repr())->assert_is_op("reshape2");
   auto* qkv_reshape_out = pattern->NewNode(qkv_reshape_out_repr())
                               ->assert_is_op_output("reshape2", "Out")
                               ->assert_var_not_persistable();
-  auto* qkv_reshape_xshape = pattern->NewNode(qkv_reshape_xshape_repr())
-                                 ->assert_is_op_output("reshape2", "XShape")
-                                 ->assert_var_not_persistable();
   auto qkv_matmul_1_w = pattern->NewNode(qkv_matmul_1_w_repr())
                             ->assert_is_op_input(matmul_type_0_, "Y")
                             ->assert_is_persistable_var();
@@ -435,61 +422,70 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* qkv_add_4_out = pattern->NewNode(qkv_add_4_out_repr())
                             ->assert_is_op_output("elementwise_add", "Out")
                             ->assert_var_not_persistable();
+  PDNode* ln_2_bias = nullptr;
+  PDNode* ln_2_scale = nullptr;
+  PDNode* ln_2 = nullptr;
   PDNode* ln_2_out = nullptr;
+  PDNode* ln_2_mean = nullptr;
+  PDNode* ln_2_variance = nullptr;
   if (!norm_before_) {
-    auto* ln_2_bias = pattern->NewNode(ln_2_bias_repr())
-                          ->assert_is_op_input("layer_norm", "Bias")
-                          ->assert_is_persistable_var();
-    auto* ln_2_scale = pattern->NewNode(ln_2_scale_repr())
-                           ->assert_is_op_input("layer_norm", "Scale")
-                           ->assert_is_persistable_var();
-    auto* ln_2 = pattern->NewNode(ln_2_repr())->assert_is_op("layer_norm");
+    ln_2_bias = pattern->NewNode(ln_2_bias_repr())
+                    ->assert_is_op_input("layer_norm", "Bias")
+                    ->assert_is_persistable_var();
+    ln_2_scale = pattern->NewNode(ln_2_scale_repr())
+                     ->assert_is_op_input("layer_norm", "Scale")
+                     ->assert_is_persistable_var();
+    ln_2 = pattern->NewNode(ln_2_repr())->assert_is_op("layer_norm");
     ln_2_out = pattern->NewNode(ln_2_out_repr())
                    ->assert_is_op_output("layer_norm", "Y")
                    ->assert_var_not_persistable();
-    auto* ln_2_mean = pattern->NewNode(ln_2_mean_repr())
-                          ->assert_is_op_output("layer_norm", "Mean")
-                          ->assert_var_not_persistable();
-    auto* ln_2_variance = pattern->NewNode(ln_2_variance_repr())
-                              ->assert_is_op_output("layer_norm", "Variance")
-                              ->assert_var_not_persistable();
-    ln_2->LinksFrom({qkv_add_4_out, ln_2_bias, ln_2_scale})
-        .LinksTo({ln_2_out, ln_2_mean, ln_2_variance});
+    ln_2_mean = pattern->NewNode(ln_2_mean_repr())
+                    ->assert_is_op_output("layer_norm", "Mean")
+                    ->assert_var_not_persistable();
+    ln_2_variance = pattern->NewNode(ln_2_variance_repr())
+                        ->assert_is_op_output("layer_norm", "Variance")
+                        ->assert_var_not_persistable();
   }
   // link nodes
   PDNode* q_matmul_x = ln_0_x;
-  if (norm_before_) q_matmul_x = ln_0_out;
+  if (norm_before_) {
+    ln_0->LinksFrom({ln_0_x, ln_0_bias, ln_0_scale})
+        .LinksTo({ln_0_out, ln_0_mean, ln_0_variance});
+    q_matmul_x = ln_0_out;
+  }
   q_matmul->LinksFrom({q_matmul_x, q_matmul_w}).LinksTo({q_matmul_out});
   q_add->LinksFrom({q_matmul_out, q_add_bias}).LinksTo({q_add_out});
-  q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out, q_reshape_xshape});
-  q_transpose->LinksFrom({q_reshape_out})
-      .LinksTo({q_transpose_out, q_transpose_xshape});
+  q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out});
+  q_transpose->LinksFrom({q_reshape_out}).LinksTo({q_transpose_out});
+  PDNode* qk_matmul_x = q_transpose_out;
+  if (with_q_scale_) {
+    q_scale->LinksFrom({q_transpose_out}).LinksTo({q_scale_out});
+    qk_matmul_x = q_scale_out;
+  }
   k_matmul->LinksFrom({q_matmul_x, k_matmul_w}).LinksTo({k_matmul_out});
   k_add->LinksFrom({k_matmul_out, k_add_bias}).LinksTo({k_add_out});
-  k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out, k_reshape_xshape});
-  k_transpose->LinksFrom({k_reshape_out})
-      .LinksTo({k_transpose_out, k_transpose_xshape});
-  qk_matmul->LinksFrom({q_transpose_out, k_transpose_out})
-      .LinksTo({qk_matmul_out});
+  k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out});
+  k_transpose->LinksFrom({k_reshape_out}).LinksTo({k_transpose_out});
+  qk_matmul->LinksFrom({qk_matmul_x, k_transpose_out}).LinksTo({qk_matmul_out});
   PDNode* qk_softmax_x = qk_matmul_out;
-  if (with_mask_) qk_softmax_x = qk_add_out;
+  if (with_mask_) {
+    qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out});
+    qk_softmax_x = qk_add_out;
+  }
   qk_softmax->LinksFrom({qk_softmax_x}).LinksTo({qk_softmax_out});
   v_matmul->LinksFrom({q_matmul_x, v_matmul_w}).LinksTo({v_matmul_out});
   v_add->LinksFrom({v_matmul_out, v_add_bias}).LinksTo({v_add_out});
-  v_reshape->LinksFrom({v_add_out}).LinksTo({v_reshape_out, v_reshape_xshape});
-  v_transpose->LinksFrom({v_reshape_out})
-      .LinksTo({v_transpose_out, v_transpose_xshape});
+  v_reshape->LinksFrom({v_add_out}).LinksTo({v_reshape_out});
+  v_transpose->LinksFrom({v_reshape_out}).LinksTo({v_transpose_out});
   qkv_matmul_0->LinksFrom({qk_softmax_out, v_transpose_out})
       .LinksTo({qkv_matmul_0_out});
-  qkv_transpose->LinksFrom({qkv_matmul_0_out})
-      .LinksTo({qkv_transpose_out, qkv_transpose_xshape});
-  qkv_reshape->LinksFrom({qkv_transpose_out})
-      .LinksTo({qkv_reshape_out, qkv_reshape_xshape});
+  qkv_transpose->LinksFrom({qkv_matmul_0_out}).LinksTo({qkv_transpose_out});
+  qkv_reshape->LinksFrom({qkv_transpose_out}).LinksTo({qkv_reshape_out});
   qkv_matmul_1->LinksFrom({qkv_reshape_out, qkv_matmul_1_w})
       .LinksTo({qkv_matmul_1_out});
   qkv_add_0->LinksFrom({qkv_matmul_1_out, qkv_add_0_bias})
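After this restructuring, every LinksFrom/LinksTo call sits in the link phase and each optional edge is wired exactly once per variant: ln_0 is linked only under norm_before_, qk_add only under with_mask_, q_scale only under with_q_scale_, and qk_matmul now consumes qk_matmul_x, which resolves to q_scale_out when the explicit scale is present and to q_transpose_out otherwise.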
@@ -511,6 +507,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
         .LinksTo({qkv_add_4_out});
   } else {
     qkv_add_4->LinksFrom({qkv_add_3_out, ln_1_out}).LinksTo({qkv_add_4_out});
+    ln_2->LinksFrom({qkv_add_4_out, ln_2_bias, ln_2_scale})
+        .LinksTo({ln_2_out, ln_2_mean, ln_2_variance});
   }
 }
@@ -614,6 +612,7 @@ class MultiEncoderXPUFusePass : public FusePassBase {
                                 const std::string& matmul_type_1,
                                 const std::string& matmul_type_2,
                                 bool norm_before,
+                                bool with_q_scale,
                                 bool with_mask) const;

   bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
@@ -641,10 +640,11 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   Init(name_scope_, graph);
   std::vector<std::string> act_types{"gelu", "relu"};
-  std::vector<std::string> matmul_types_0{"mul", "matmul", "matmul_v2"};
-  std::vector<std::string> matmul_types_1{"matmul", "matmul_v2"};
-  std::vector<std::string> matmul_types_2{"matmul", "matmul_v2"};
+  std::vector<std::string> matmul_types_0{"matmul_v2", "matmul", "mul"};
+  std::vector<std::string> matmul_types_1{"matmul_v2", "matmul"};
+  std::vector<std::string> matmul_types_2{"matmul_v2", "matmul"};
   std::vector<bool> norm_befores{true, false};
+  std::vector<bool> with_q_scales{true, false};
   std::vector<bool> with_masks{true, false};
   int single_encoder_fused_counts = 0;
   int multi_encoder_fused_counts = 0;
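The option vectors change in two ways: with_q_scales is a new axis, and the matmul type lists are reordered so that matmul_v2, likely the most common operator in current exported inference programs, is tried before the legacy matmul and mul forms.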
@@ -653,17 +653,20 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
   for (auto matmul_type_1 : matmul_types_1) {
     for (auto matmul_type_2 : matmul_types_2) {
       for (auto norm_before : norm_befores) {
-        for (auto with_mask : with_masks) {
-          single_encoder_fused_counts +=
-              ApplySingleEncoderXPUFuse(graph,
-                                        act_type,
-                                        matmul_type_0,
-                                        matmul_type_1,
-                                        matmul_type_2,
-                                        norm_before,
-                                        with_mask);
-          while (ApplyMultiEncoderXPUFuse(graph)) {
-            multi_encoder_fused_counts++;
+        for (auto with_q_scale : with_q_scales) {
+          for (auto with_mask : with_masks) {
+            single_encoder_fused_counts +=
+                ApplySingleEncoderXPUFuse(graph,
+                                          act_type,
+                                          matmul_type_0,
+                                          matmul_type_1,
+                                          matmul_type_2,
+                                          norm_before,
+                                          with_q_scale,
+                                          with_mask);
+            while (ApplyMultiEncoderXPUFuse(graph)) {
+              multi_encoder_fused_counts++;
+            }
           }
         }
       }
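The extra with_q_scale loop multiplies out against the other option axes. A quick standalone check of the variant count, using the vectors from the hunk above (not Paddle code):

#include <cstdio>
#include <string>
#include <vector>

// One pattern is instantiated per combination of the option vectors, so the
// new with_q_scales axis doubles the variants tried per graph from 96 to 192.
int main() {
  std::vector<std::string> act_types{"gelu", "relu"};
  std::vector<std::string> matmul_types_0{"matmul_v2", "matmul", "mul"};
  std::vector<std::string> matmul_types_1{"matmul_v2", "matmul"};
  std::vector<std::string> matmul_types_2{"matmul_v2", "matmul"};
  std::vector<bool> norm_befores{true, false};
  std::vector<bool> with_q_scales{true, false};
  std::vector<bool> with_masks{true, false};

  auto combos = act_types.size() * matmul_types_0.size() *
                matmul_types_1.size() * matmul_types_2.size() *
                norm_befores.size() * with_q_scales.size() * with_masks.size();
  std::printf("pattern variants: %zu\n", combos);  // 2*3*2*2*2*2*2 = 192
  return 0;
}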
@@ -734,6 +737,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
     const std::string& matmul_type_1,
     const std::string& matmul_type_2,
     bool norm_before,
+    bool with_q_scale,
     bool with_mask) const {
   GraphPatternDetector gpd;
   patterns::SingleEncoderXPUPattern pattern(gpd.mutable_pattern(),
@@ -743,6 +747,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
                                             matmul_type_1,
                                             matmul_type_2,
                                             norm_before,
+                                            with_q_scale,
                                             with_mask);
   int found_subgraph_count = 0;
@@ -756,6 +761,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
     GET_IR_NODE(q_add);
     GET_IR_NODE(q_reshape);
     GET_IR_NODE(q_transpose);
+    GET_IR_NODE(q_scale);
     GET_IR_NODE(k_matmul);
     GET_IR_NODE(k_add);
     GET_IR_NODE(k_reshape);
@@ -790,34 +796,27 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
     GET_IR_NODE(q_add_bias);
     GET_IR_NODE(q_add_out);
     GET_IR_NODE(q_reshape_out);
-    GET_IR_NODE(q_reshape_xshape);
     GET_IR_NODE(q_transpose_out);
-    GET_IR_NODE(q_transpose_xshape);
+    GET_IR_NODE(q_scale_out);
     GET_IR_NODE(k_matmul_w);
     GET_IR_NODE(k_matmul_out);
     GET_IR_NODE(k_add_bias);
     GET_IR_NODE(k_add_out);
     GET_IR_NODE(k_reshape_out);
-    GET_IR_NODE(k_reshape_xshape);
     GET_IR_NODE(k_transpose_out);
-    GET_IR_NODE(k_transpose_xshape);
     GET_IR_NODE(v_matmul_w);
     GET_IR_NODE(v_matmul_out);
    GET_IR_NODE(v_add_bias);
     GET_IR_NODE(v_add_out);
     GET_IR_NODE(v_reshape_out);
-    GET_IR_NODE(v_reshape_xshape);
     GET_IR_NODE(v_transpose_out);
-    GET_IR_NODE(v_transpose_xshape);
     GET_IR_NODE(qk_matmul_out);
     GET_IR_NODE(qk_add_mask);
     GET_IR_NODE(qk_add_out);
     GET_IR_NODE(qk_softmax_out);
     GET_IR_NODE(qkv_matmul_0_out);
     GET_IR_NODE(qkv_transpose_out);
-    GET_IR_NODE(qkv_transpose_xshape);
     GET_IR_NODE(qkv_reshape_out);
-    GET_IR_NODE(qkv_reshape_xshape);
     GET_IR_NODE(qkv_matmul_1_w);
     GET_IR_NODE(qkv_matmul_1_out);
     GET_IR_NODE(qkv_add_0_bias);
@@ -1019,30 +1018,22 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
         q_matmul_out,
         q_add_out,
         q_reshape_out,
-        q_reshape_xshape,
         q_transpose_out,
-        q_transpose_xshape,
         k_matmul_w,
         k_matmul_out,
         k_add_out,
         k_reshape_out,
-        k_reshape_xshape,
         k_transpose_out,
-        k_transpose_xshape,
         v_matmul_w,
         v_matmul_out,
         v_add_out,
         v_reshape_out,
-        v_reshape_xshape,
         v_transpose_out,
-        v_transpose_xshape,
         qk_matmul_out,
         qk_softmax_out,
         qkv_matmul_0_out,
         qkv_transpose_out,
-        qkv_transpose_xshape,
         qkv_reshape_out,
-        qkv_reshape_xshape,
         qkv_matmul_1_out,
         qkv_add_0_out,
         qkv_add_1_out,
@@ -1065,6 +1056,10 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
       delete_nodes.insert(ln_2_mean);
       delete_nodes.insert(ln_2_variance);
     }
+    if (with_q_scale) {
+      delete_nodes.insert(q_scale);
+      delete_nodes.insert(q_scale_out);
+    }
     if (with_mask) {
       delete_nodes.insert(qk_add);
       delete_nodes.insert(qk_add_out);
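The xshape variables likewise disappear from the delete_nodes set above, while q_scale and q_scale_out, which exist only in with_q_scale subgraphs, are erased conditionally, mirroring the existing with_mask handling.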
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -517,6 +517,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
 XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
   passes_.assign({"delete_dropout_op_pass",
+                  "identity_scale_op_clean_pass",
                   "generate_sequence_xpu_fuse_pass",
                   "multi_encoder_xpu_fuse_pass",
                   "multi_encoder_xpu_slice_fuse_pass",
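The single added line registers identity_scale_op_clean_pass in XpuPassStrategy ahead of multi_encoder_xpu_fuse_pass, presumably so that no-op scale operators are cleaned away before the encoder patterns are matched.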