Unverified commit 5c9299e5, authored by zhupengyang, committed by GitHub

[XPU] optimize multi_encoder_xpu_pass (#50759)

Parent 91992dac
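
In brief, this change (1) adds a with_q_scale option to SingleEncoderXPUPattern so the attention sub-block is matched both when Q passes through a standalone scale op after transpose2 and when it feeds the QK matmul directly, (2) drops the XShape outputs of reshape2/transpose2 from the matched pattern, (3) enumerates matmul_v2 before matmul and mul, and (4) registers identity_scale_op_clean_pass ahead of multi_encoder_xpu_fuse_pass in XpuPassStrategy. For context, the sub-graph being matched is ordinary scaled dot-product attention; the formula below is an editorial illustration, and d_head is an assumed symbol rather than a name from the code:

    % Scaled dot-product attention targeted by the pattern (illustrative).
    % The 1/sqrt(d_head) factor is the standalone "scale" op matched when
    % with_q_scale is true; the additive mask M is present only when
    % with_mask is true.
    \mathrm{Attention}(Q, K, V) =
        \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}} + M\right) V

With the extra flag, ApplyImpl now tries 2 (act_type) x 3 (matmul_type_0) x 2 (matmul_type_1) x 2 (matmul_type_2) x 2 (norm_before) x 2 (with_q_scale) x 2 (with_mask) = 192 pattern configurations per graph, twice as many as before.
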
@@ -55,6 +55,7 @@ struct SingleEncoderXPUPattern : public PatternBase {
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_q_scale,
bool with_mask);
// declare operator node's name
@@ -67,6 +68,7 @@ struct SingleEncoderXPUPattern : public PatternBase {
PATTERN_DECL_NODE(q_add);
PATTERN_DECL_NODE(q_reshape);
PATTERN_DECL_NODE(q_transpose);
PATTERN_DECL_NODE(q_scale);
PATTERN_DECL_NODE(k_matmul);
PATTERN_DECL_NODE(k_add);
PATTERN_DECL_NODE(k_reshape);
@@ -102,34 +104,27 @@ struct SingleEncoderXPUPattern : public PatternBase {
PATTERN_DECL_NODE(q_add_bias);
PATTERN_DECL_NODE(q_add_out);
PATTERN_DECL_NODE(q_reshape_out);
PATTERN_DECL_NODE(q_reshape_xshape);
PATTERN_DECL_NODE(q_transpose_out);
PATTERN_DECL_NODE(q_transpose_xshape);
PATTERN_DECL_NODE(q_scale_out);
PATTERN_DECL_NODE(k_matmul_w);
PATTERN_DECL_NODE(k_matmul_out);
PATTERN_DECL_NODE(k_add_bias);
PATTERN_DECL_NODE(k_add_out);
PATTERN_DECL_NODE(k_reshape_out);
PATTERN_DECL_NODE(k_reshape_xshape);
PATTERN_DECL_NODE(k_transpose_out);
PATTERN_DECL_NODE(k_transpose_xshape);
PATTERN_DECL_NODE(v_matmul_w);
PATTERN_DECL_NODE(v_matmul_out);
PATTERN_DECL_NODE(v_add_bias);
PATTERN_DECL_NODE(v_add_out);
PATTERN_DECL_NODE(v_reshape_out);
PATTERN_DECL_NODE(v_reshape_xshape);
PATTERN_DECL_NODE(v_transpose_out);
PATTERN_DECL_NODE(v_transpose_xshape);
PATTERN_DECL_NODE(qk_matmul_out);
PATTERN_DECL_NODE(qk_add_mask);
PATTERN_DECL_NODE(qk_add_out);
PATTERN_DECL_NODE(qk_softmax_out);
PATTERN_DECL_NODE(qkv_matmul_0_out);
PATTERN_DECL_NODE(qkv_transpose_out);
PATTERN_DECL_NODE(qkv_transpose_xshape);
PATTERN_DECL_NODE(qkv_reshape_out);
PATTERN_DECL_NODE(qkv_reshape_xshape);
PATTERN_DECL_NODE(qkv_matmul_1_w);
PATTERN_DECL_NODE(qkv_matmul_1_out);
PATTERN_DECL_NODE(qkv_add_0_bias);
@@ -162,7 +157,8 @@ struct SingleEncoderXPUPattern : public PatternBase {
std::string matmul_type_0_;
std::string matmul_type_1_;
std::string matmul_type_2_;
bool norm_before_{true};
bool norm_before_{false};
bool with_q_scale_{false};
bool with_mask_{true};
};
@@ -174,6 +170,7 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_q_scale,
bool with_mask)
: PatternBase(pattern, name_scope, name_scope),
act_type_(act_type),
@@ -181,30 +178,34 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
matmul_type_1_(matmul_type_1),
matmul_type_2_(matmul_type_2),
norm_before_(norm_before),
with_q_scale_(with_q_scale),
with_mask_(with_mask) {
// layer_norm 0
PDNode* ln_0_x = pattern->NewNode(ln_0_x_repr());
PDNode* ln_0_bias = nullptr;
PDNode* ln_0_scale = nullptr;
PDNode* ln_0 = nullptr;
PDNode* ln_0_out = nullptr;
PDNode* ln_0_mean = nullptr;
PDNode* ln_0_variance = nullptr;
if (norm_before_) {
ln_0_x->assert_is_op_input("layer_norm", "X")->assert_var_not_persistable();
auto* ln_0_bias = pattern->NewNode(ln_0_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->assert_is_persistable_var();
auto* ln_0_scale = pattern->NewNode(ln_0_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->assert_is_persistable_var();
auto* ln_0 = pattern->NewNode(ln_0_repr())->assert_is_op("layer_norm");
ln_0_bias = pattern->NewNode(ln_0_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->assert_is_persistable_var();
ln_0_scale = pattern->NewNode(ln_0_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->assert_is_persistable_var();
ln_0 = pattern->NewNode(ln_0_repr())->assert_is_op("layer_norm");
ln_0_out = pattern->NewNode(ln_0_out_repr())
->assert_is_op_output("layer_norm", "Y")
->assert_var_not_persistable();
auto* ln_0_mean = pattern->NewNode(ln_0_mean_repr())
->assert_is_op_output("layer_norm", "Mean")
->assert_var_not_persistable();
auto* ln_0_variance = pattern->NewNode(ln_0_variance_repr())
->assert_is_op_output("layer_norm", "Variance")
->assert_var_not_persistable();
ln_0->LinksFrom({ln_0_x, ln_0_bias, ln_0_scale})
.LinksTo({ln_0_out, ln_0_mean, ln_0_variance});
ln_0_mean = pattern->NewNode(ln_0_mean_repr())
->assert_is_op_output("layer_norm", "Mean")
->assert_var_not_persistable();
ln_0_variance = pattern->NewNode(ln_0_variance_repr())
->assert_is_op_output("layer_norm", "Variance")
->assert_var_not_persistable();
}
// q: matmul + add + reshape + transpose
@@ -228,18 +229,22 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
auto* q_reshape_out = pattern->NewNode(q_reshape_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_var_not_persistable();
auto* q_reshape_xshape = pattern->NewNode(q_reshape_xshape_repr())
->assert_is_op_output("reshape2", "XShape")
->assert_var_not_persistable();
auto* q_transpose =
pattern->NewNode(q_transpose_repr())->assert_is_op("transpose2");
auto* q_transpose_out = pattern->NewNode(q_transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input(matmul_type_1_, "X")
->assert_var_not_persistable();
auto* q_transpose_xshape = pattern->NewNode(q_transpose_xshape_repr())
->assert_is_op_output("transpose2", "XShape")
->assert_var_not_persistable();
PDNode* q_scale = nullptr;
PDNode* q_scale_out = nullptr;
if (with_q_scale_) {
q_scale = pattern->NewNode(q_scale_repr())->assert_is_op("scale");
q_scale_out = pattern->NewNode(q_scale_out_repr())
->assert_is_op_output("scale", "Out")
->assert_is_op_input(matmul_type_1_, "X")
->assert_var_not_persistable();
} else {
q_transpose_out->assert_is_op_input(matmul_type_1_, "X");
}
// k: matmul + add + reshape + transpose
auto k_matmul_w = pattern->NewNode(k_matmul_w_repr())
@@ -262,18 +267,12 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
auto* k_reshape_out = pattern->NewNode(k_reshape_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_var_not_persistable();
auto* k_reshape_xshape = pattern->NewNode(k_reshape_xshape_repr())
->assert_is_op_output("reshape2", "XShape")
->assert_var_not_persistable();
auto* k_transpose =
pattern->NewNode(k_transpose_repr())->assert_is_op("transpose2");
auto* k_transpose_out = pattern->NewNode(k_transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input(matmul_type_1_, "Y")
->assert_var_not_persistable();
auto* k_transpose_xshape = pattern->NewNode(k_transpose_xshape_repr())
->assert_is_op_output("transpose2", "XShape")
->assert_var_not_persistable();
// qk: matmul + add + softmax
auto* qk_matmul =
@@ -281,17 +280,17 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
auto* qk_matmul_out = pattern->NewNode(qk_matmul_out_repr())
->assert_is_op_output(matmul_type_1_, "Out")
->assert_var_not_persistable();
PDNode* qk_add_mask = nullptr;
PDNode* qk_add = nullptr;
PDNode* qk_add_out = nullptr;
if (with_mask_) {
auto qk_add_mask = pattern->NewNode(qk_add_mask_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_var_not_persistable();
auto* qk_add =
pattern->NewNode(qk_add_repr())->assert_is_op("elementwise_add");
qk_add_mask = pattern->NewNode(qk_add_mask_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_var_not_persistable();
qk_add = pattern->NewNode(qk_add_repr())->assert_is_op("elementwise_add");
qk_add_out = pattern->NewNode(qk_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out});
}
auto* qk_softmax =
pattern->NewNode(qk_softmax_repr())->assert_is_op("softmax");
@@ -321,18 +320,12 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
auto* v_reshape_out = pattern->NewNode(v_reshape_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_var_not_persistable();
auto* v_reshape_xshape = pattern->NewNode(v_reshape_xshape_repr())
->assert_is_op_output("reshape2", "XShape")
->assert_var_not_persistable();
auto* v_transpose =
pattern->NewNode(v_transpose_repr())->assert_is_op("transpose2");
auto* v_transpose_out = pattern->NewNode(v_transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input(matmul_type_2_, "Y")
->assert_var_not_persistable();
auto* v_transpose_xshape = pattern->NewNode(v_transpose_xshape_repr())
->assert_is_op_output("transpose2", "XShape")
->assert_var_not_persistable();
// qkv
auto* qkv_matmul_0 =
@@ -345,17 +338,11 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
auto* qkv_transpose_out = pattern->NewNode(qkv_transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_var_not_persistable();
auto* qkv_transpose_xshape = pattern->NewNode(qkv_transpose_xshape_repr())
->assert_is_op_output("transpose2", "XShape")
->assert_var_not_persistable();
auto* qkv_reshape =
pattern->NewNode(qkv_reshape_repr())->assert_is_op("reshape2");
auto* qkv_reshape_out = pattern->NewNode(qkv_reshape_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_var_not_persistable();
auto* qkv_reshape_xshape = pattern->NewNode(qkv_reshape_xshape_repr())
->assert_is_op_output("reshape2", "XShape")
->assert_var_not_persistable();
auto qkv_matmul_1_w = pattern->NewNode(qkv_matmul_1_w_repr())
->assert_is_op_input(matmul_type_0_, "Y")
->assert_is_persistable_var();
@@ -435,61 +422,70 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
auto* qkv_add_4_out = pattern->NewNode(qkv_add_4_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
PDNode* ln_2_bias = nullptr;
PDNode* ln_2_scale = nullptr;
PDNode* ln_2 = nullptr;
PDNode* ln_2_out = nullptr;
PDNode* ln_2_mean = nullptr;
PDNode* ln_2_variance = nullptr;
if (!norm_before_) {
auto* ln_2_bias = pattern->NewNode(ln_2_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->assert_is_persistable_var();
auto* ln_2_scale = pattern->NewNode(ln_2_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->assert_is_persistable_var();
auto* ln_2 = pattern->NewNode(ln_2_repr())->assert_is_op("layer_norm");
ln_2_bias = pattern->NewNode(ln_2_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->assert_is_persistable_var();
ln_2_scale = pattern->NewNode(ln_2_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->assert_is_persistable_var();
ln_2 = pattern->NewNode(ln_2_repr())->assert_is_op("layer_norm");
ln_2_out = pattern->NewNode(ln_2_out_repr())
->assert_is_op_output("layer_norm", "Y")
->assert_var_not_persistable();
auto* ln_2_mean = pattern->NewNode(ln_2_mean_repr())
->assert_is_op_output("layer_norm", "Mean")
->assert_var_not_persistable();
auto* ln_2_variance = pattern->NewNode(ln_2_variance_repr())
->assert_is_op_output("layer_norm", "Variance")
->assert_var_not_persistable();
ln_2->LinksFrom({qkv_add_4_out, ln_2_bias, ln_2_scale})
.LinksTo({ln_2_out, ln_2_mean, ln_2_variance});
ln_2_mean = pattern->NewNode(ln_2_mean_repr())
->assert_is_op_output("layer_norm", "Mean")
->assert_var_not_persistable();
ln_2_variance = pattern->NewNode(ln_2_variance_repr())
->assert_is_op_output("layer_norm", "Variance")
->assert_var_not_persistable();
}
// link nodes
PDNode* q_matmul_x = ln_0_x;
if (norm_before_) q_matmul_x = ln_0_out;
if (norm_before_) {
ln_0->LinksFrom({ln_0_x, ln_0_bias, ln_0_scale})
.LinksTo({ln_0_out, ln_0_mean, ln_0_variance});
q_matmul_x = ln_0_out;
}
q_matmul->LinksFrom({q_matmul_x, q_matmul_w}).LinksTo({q_matmul_out});
q_add->LinksFrom({q_matmul_out, q_add_bias}).LinksTo({q_add_out});
q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out, q_reshape_xshape});
q_transpose->LinksFrom({q_reshape_out})
.LinksTo({q_transpose_out, q_transpose_xshape});
q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out});
q_transpose->LinksFrom({q_reshape_out}).LinksTo({q_transpose_out});
PDNode* qk_matmul_x = q_transpose_out;
if (with_q_scale_) {
q_scale->LinksFrom({q_transpose_out}).LinksTo({q_scale_out});
qk_matmul_x = q_scale_out;
}
k_matmul->LinksFrom({q_matmul_x, k_matmul_w}).LinksTo({k_matmul_out});
k_add->LinksFrom({k_matmul_out, k_add_bias}).LinksTo({k_add_out});
k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out, k_reshape_xshape});
k_transpose->LinksFrom({k_reshape_out})
.LinksTo({k_transpose_out, k_transpose_xshape});
k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out});
k_transpose->LinksFrom({k_reshape_out}).LinksTo({k_transpose_out});
qk_matmul->LinksFrom({q_transpose_out, k_transpose_out})
.LinksTo({qk_matmul_out});
qk_matmul->LinksFrom({qk_matmul_x, k_transpose_out}).LinksTo({qk_matmul_out});
PDNode* qk_softmax_x = qk_matmul_out;
if (with_mask_) qk_softmax_x = qk_add_out;
if (with_mask_) {
qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out});
qk_softmax_x = qk_add_out;
}
qk_softmax->LinksFrom({qk_softmax_x}).LinksTo({qk_softmax_out});
v_matmul->LinksFrom({q_matmul_x, v_matmul_w}).LinksTo({v_matmul_out});
v_add->LinksFrom({v_matmul_out, v_add_bias}).LinksTo({v_add_out});
v_reshape->LinksFrom({v_add_out}).LinksTo({v_reshape_out, v_reshape_xshape});
v_transpose->LinksFrom({v_reshape_out})
.LinksTo({v_transpose_out, v_transpose_xshape});
v_reshape->LinksFrom({v_add_out}).LinksTo({v_reshape_out});
v_transpose->LinksFrom({v_reshape_out}).LinksTo({v_transpose_out});
qkv_matmul_0->LinksFrom({qk_softmax_out, v_transpose_out})
.LinksTo({qkv_matmul_0_out});
qkv_transpose->LinksFrom({qkv_matmul_0_out})
.LinksTo({qkv_transpose_out, qkv_transpose_xshape});
qkv_reshape->LinksFrom({qkv_transpose_out})
.LinksTo({qkv_reshape_out, qkv_reshape_xshape});
qkv_transpose->LinksFrom({qkv_matmul_0_out}).LinksTo({qkv_transpose_out});
qkv_reshape->LinksFrom({qkv_transpose_out}).LinksTo({qkv_reshape_out});
qkv_matmul_1->LinksFrom({qkv_reshape_out, qkv_matmul_1_w})
.LinksTo({qkv_matmul_1_out});
qkv_add_0->LinksFrom({qkv_matmul_1_out, qkv_add_0_bias})
@@ -511,6 +507,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
.LinksTo({qkv_add_4_out});
} else {
qkv_add_4->LinksFrom({qkv_add_3_out, ln_1_out}).LinksTo({qkv_add_4_out});
ln_2->LinksFrom({qkv_add_4_out, ln_2_bias, ln_2_scale})
.LinksTo({ln_2_out, ln_2_mean, ln_2_variance});
}
}
@@ -614,6 +612,7 @@ class MultiEncoderXPUFusePass : public FusePassBase {
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_q_scale,
bool with_mask) const;
bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
@@ -641,10 +640,11 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
std::vector<std::string> act_types{"gelu", "relu"};
std::vector<std::string> matmul_types_0{"mul", "matmul", "matmul_v2"};
std::vector<std::string> matmul_types_1{"matmul", "matmul_v2"};
std::vector<std::string> matmul_types_2{"matmul", "matmul_v2"};
std::vector<std::string> matmul_types_0{"matmul_v2", "matmul", "mul"};
std::vector<std::string> matmul_types_1{"matmul_v2", "matmul"};
std::vector<std::string> matmul_types_2{"matmul_v2", "matmul"};
std::vector<bool> norm_befores{true, false};
std::vector<bool> with_q_scales{true, false};
std::vector<bool> with_masks{true, false};
int single_encoder_fused_counts = 0;
int multi_encoder_fused_counts = 0;
@@ -653,17 +653,20 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
for (auto matmul_type_1 : matmul_types_1) {
for (auto matmul_type_2 : matmul_types_2) {
for (auto norm_before : norm_befores) {
for (auto with_mask : with_masks) {
single_encoder_fused_counts +=
ApplySingleEncoderXPUFuse(graph,
act_type,
matmul_type_0,
matmul_type_1,
matmul_type_2,
norm_before,
with_mask);
while (ApplyMultiEncoderXPUFuse(graph)) {
multi_encoder_fused_counts++;
for (auto with_q_scale : with_q_scales) {
for (auto with_mask : with_masks) {
single_encoder_fused_counts +=
ApplySingleEncoderXPUFuse(graph,
act_type,
matmul_type_0,
matmul_type_1,
matmul_type_2,
norm_before,
with_q_scale,
with_mask);
while (ApplyMultiEncoderXPUFuse(graph)) {
multi_encoder_fused_counts++;
}
}
}
}
@@ -734,6 +737,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_q_scale,
bool with_mask) const {
GraphPatternDetector gpd;
patterns::SingleEncoderXPUPattern pattern(gpd.mutable_pattern(),
@@ -743,6 +747,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
matmul_type_1,
matmul_type_2,
norm_before,
with_q_scale,
with_mask);
int found_subgraph_count = 0;
@@ -756,6 +761,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
GET_IR_NODE(q_add);
GET_IR_NODE(q_reshape);
GET_IR_NODE(q_transpose);
GET_IR_NODE(q_scale);
GET_IR_NODE(k_matmul);
GET_IR_NODE(k_add);
GET_IR_NODE(k_reshape);
@@ -790,34 +796,27 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
GET_IR_NODE(q_add_bias);
GET_IR_NODE(q_add_out);
GET_IR_NODE(q_reshape_out);
GET_IR_NODE(q_reshape_xshape);
GET_IR_NODE(q_transpose_out);
GET_IR_NODE(q_transpose_xshape);
GET_IR_NODE(q_scale_out);
GET_IR_NODE(k_matmul_w);
GET_IR_NODE(k_matmul_out);
GET_IR_NODE(k_add_bias);
GET_IR_NODE(k_add_out);
GET_IR_NODE(k_reshape_out);
GET_IR_NODE(k_reshape_xshape);
GET_IR_NODE(k_transpose_out);
GET_IR_NODE(k_transpose_xshape);
GET_IR_NODE(v_matmul_w);
GET_IR_NODE(v_matmul_out);
GET_IR_NODE(v_add_bias);
GET_IR_NODE(v_add_out);
GET_IR_NODE(v_reshape_out);
GET_IR_NODE(v_reshape_xshape);
GET_IR_NODE(v_transpose_out);
GET_IR_NODE(v_transpose_xshape);
GET_IR_NODE(qk_matmul_out);
GET_IR_NODE(qk_add_mask);
GET_IR_NODE(qk_add_out);
GET_IR_NODE(qk_softmax_out);
GET_IR_NODE(qkv_matmul_0_out);
GET_IR_NODE(qkv_transpose_out);
GET_IR_NODE(qkv_transpose_xshape);
GET_IR_NODE(qkv_reshape_out);
GET_IR_NODE(qkv_reshape_xshape);
GET_IR_NODE(qkv_matmul_1_w);
GET_IR_NODE(qkv_matmul_1_out);
GET_IR_NODE(qkv_add_0_bias);
@@ -1019,30 +1018,22 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
q_matmul_out,
q_add_out,
q_reshape_out,
q_reshape_xshape,
q_transpose_out,
q_transpose_xshape,
k_matmul_w,
k_matmul_out,
k_add_out,
k_reshape_out,
k_reshape_xshape,
k_transpose_out,
k_transpose_xshape,
v_matmul_w,
v_matmul_out,
v_add_out,
v_reshape_out,
v_reshape_xshape,
v_transpose_out,
v_transpose_xshape,
qk_matmul_out,
qk_softmax_out,
qkv_matmul_0_out,
qkv_transpose_out,
qkv_transpose_xshape,
qkv_reshape_out,
qkv_reshape_xshape,
qkv_matmul_1_out,
qkv_add_0_out,
qkv_add_1_out,
@@ -1065,6 +1056,10 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
delete_nodes.insert(ln_2_mean);
delete_nodes.insert(ln_2_variance);
}
if (with_q_scale) {
delete_nodes.insert(q_scale);
delete_nodes.insert(q_scale_out);
}
if (with_mask) {
delete_nodes.insert(qk_add);
delete_nodes.insert(qk_add_out);
......
@@ -517,6 +517,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"delete_dropout_op_pass",
"identity_scale_op_clean_pass",
"generate_sequence_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
......
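
Note on the pass-list change above (an editorial inference, not stated in the commit): identity_scale_op_clean_pass, which as its name suggests removes scale ops that leave their input unchanged (scale = 1, bias = 0), now runs before multi_encoder_xpu_fuse_pass, presumably so that such no-op scales cannot break the encoder pattern match while genuine attention scales remain available for the with_q_scale branch.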