Unverified commit 5c9299e5, authored by zhupengyang, committed by GitHub

[XPU] optimize multi_encoder_xpu_pass (#50759)

Parent 91992dac
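Summary: this change makes the single-encoder pattern of multi_encoder_xpu_fuse_pass tolerant of more export variants of the attention block. A new with_q_scale switch lets the pattern match graphs where the attention scaling (typically 1/sqrt(head_dim)) appears as an explicit scale op between the q transpose2 and the qk matmul, as well as graphs without that op. The reshape2/transpose2 XShape outputs are dropped from the pattern; conditionally matched nodes (ln_0, ln_2, qk_add, q_scale) become null-initialized PDNode* handles that are created and linked only inside their branch; the matmul type candidates are reordered to try matmul_v2 first; and identity_scale_op_clean_pass is registered ahead of the fuse passes in XpuPassStrategy.

The optional-subgraph idiom at the heart of the change, condensed from the hunks below (an illustrative excerpt, not a standalone compilable unit):

// Optional-subgraph matching: declare the handles outside the branch, create
// the nodes only when the variant is enabled, and wire the edges under the
// same condition, so one pattern covers both graph shapes.
PDNode* q_scale = nullptr;
PDNode* q_scale_out = nullptr;
if (with_q_scale_) {
  q_scale = pattern->NewNode(q_scale_repr())->assert_is_op("scale");
  q_scale_out = pattern->NewNode(q_scale_out_repr())
                    ->assert_is_op_output("scale", "Out")
                    ->assert_is_op_input(matmul_type_1_, "X");
} else {
  // Without the extra scale op, q's transpose output feeds qk_matmul directly.
  q_transpose_out->assert_is_op_input(matmul_type_1_, "X");
}
// Later, when wiring edges, pick whichever node actually feeds qk_matmul.
PDNode* qk_matmul_x = q_transpose_out;
if (with_q_scale_) {
  q_scale->LinksFrom({q_transpose_out}).LinksTo({q_scale_out});
  qk_matmul_x = q_scale_out;
}
qk_matmul->LinksFrom({qk_matmul_x, k_transpose_out}).LinksTo({qk_matmul_out});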
@@ -55,6 +55,7 @@ struct SingleEncoderXPUPattern : public PatternBase {
                           const std::string& matmul_type_1,
                           const std::string& matmul_type_2,
                           bool norm_before,
+                          bool with_q_scale,
                           bool with_mask);

   // declare operator node's name
@@ -67,6 +68,7 @@ struct SingleEncoderXPUPattern : public PatternBase {
   PATTERN_DECL_NODE(q_add);
   PATTERN_DECL_NODE(q_reshape);
   PATTERN_DECL_NODE(q_transpose);
+  PATTERN_DECL_NODE(q_scale);
   PATTERN_DECL_NODE(k_matmul);
   PATTERN_DECL_NODE(k_add);
   PATTERN_DECL_NODE(k_reshape);
@@ -102,34 +104,27 @@ struct SingleEncoderXPUPattern : public PatternBase {
   PATTERN_DECL_NODE(q_add_bias);
   PATTERN_DECL_NODE(q_add_out);
   PATTERN_DECL_NODE(q_reshape_out);
-  PATTERN_DECL_NODE(q_reshape_xshape);
   PATTERN_DECL_NODE(q_transpose_out);
-  PATTERN_DECL_NODE(q_transpose_xshape);
+  PATTERN_DECL_NODE(q_scale_out);
   PATTERN_DECL_NODE(k_matmul_w);
   PATTERN_DECL_NODE(k_matmul_out);
   PATTERN_DECL_NODE(k_add_bias);
   PATTERN_DECL_NODE(k_add_out);
   PATTERN_DECL_NODE(k_reshape_out);
-  PATTERN_DECL_NODE(k_reshape_xshape);
   PATTERN_DECL_NODE(k_transpose_out);
-  PATTERN_DECL_NODE(k_transpose_xshape);
   PATTERN_DECL_NODE(v_matmul_w);
   PATTERN_DECL_NODE(v_matmul_out);
   PATTERN_DECL_NODE(v_add_bias);
   PATTERN_DECL_NODE(v_add_out);
   PATTERN_DECL_NODE(v_reshape_out);
-  PATTERN_DECL_NODE(v_reshape_xshape);
   PATTERN_DECL_NODE(v_transpose_out);
-  PATTERN_DECL_NODE(v_transpose_xshape);
   PATTERN_DECL_NODE(qk_matmul_out);
   PATTERN_DECL_NODE(qk_add_mask);
   PATTERN_DECL_NODE(qk_add_out);
   PATTERN_DECL_NODE(qk_softmax_out);
   PATTERN_DECL_NODE(qkv_matmul_0_out);
   PATTERN_DECL_NODE(qkv_transpose_out);
-  PATTERN_DECL_NODE(qkv_transpose_xshape);
   PATTERN_DECL_NODE(qkv_reshape_out);
-  PATTERN_DECL_NODE(qkv_reshape_xshape);
   PATTERN_DECL_NODE(qkv_matmul_1_w);
   PATTERN_DECL_NODE(qkv_matmul_1_out);
   PATTERN_DECL_NODE(qkv_add_0_bias);
@@ -162,7 +157,8 @@ struct SingleEncoderXPUPattern : public PatternBase {
   std::string matmul_type_0_;
   std::string matmul_type_1_;
   std::string matmul_type_2_;
-  bool norm_before_{true};
+  bool norm_before_{false};
+  bool with_q_scale_{false};
   bool with_mask_{true};
 };
@@ -174,6 +170,7 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
     const std::string& matmul_type_1,
     const std::string& matmul_type_2,
     bool norm_before,
+    bool with_q_scale,
     bool with_mask)
     : PatternBase(pattern, name_scope, name_scope),
       act_type_(act_type),
@@ -181,30 +178,34 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
       matmul_type_1_(matmul_type_1),
       matmul_type_2_(matmul_type_2),
       norm_before_(norm_before),
+      with_q_scale_(with_q_scale),
       with_mask_(with_mask) {
   // layer_norm 0
   PDNode* ln_0_x = pattern->NewNode(ln_0_x_repr());
+  PDNode* ln_0_bias = nullptr;
+  PDNode* ln_0_scale = nullptr;
+  PDNode* ln_0 = nullptr;
   PDNode* ln_0_out = nullptr;
+  PDNode* ln_0_mean = nullptr;
+  PDNode* ln_0_variance = nullptr;
   if (norm_before_) {
     ln_0_x->assert_is_op_input("layer_norm", "X")->assert_var_not_persistable();
-    auto* ln_0_bias = pattern->NewNode(ln_0_bias_repr())
-                          ->assert_is_op_input("layer_norm", "Bias")
-                          ->assert_is_persistable_var();
-    auto* ln_0_scale = pattern->NewNode(ln_0_scale_repr())
-                           ->assert_is_op_input("layer_norm", "Scale")
-                           ->assert_is_persistable_var();
-    auto* ln_0 = pattern->NewNode(ln_0_repr())->assert_is_op("layer_norm");
+    ln_0_bias = pattern->NewNode(ln_0_bias_repr())
                    ->assert_is_op_input("layer_norm", "Bias")
                    ->assert_is_persistable_var();
+    ln_0_scale = pattern->NewNode(ln_0_scale_repr())
                     ->assert_is_op_input("layer_norm", "Scale")
                     ->assert_is_persistable_var();
+    ln_0 = pattern->NewNode(ln_0_repr())->assert_is_op("layer_norm");
     ln_0_out = pattern->NewNode(ln_0_out_repr())
                    ->assert_is_op_output("layer_norm", "Y")
                    ->assert_var_not_persistable();
-    auto* ln_0_mean = pattern->NewNode(ln_0_mean_repr())
-                          ->assert_is_op_output("layer_norm", "Mean")
-                          ->assert_var_not_persistable();
-    auto* ln_0_variance = pattern->NewNode(ln_0_variance_repr())
-                              ->assert_is_op_output("layer_norm", "Variance")
-                              ->assert_var_not_persistable();
-    ln_0->LinksFrom({ln_0_x, ln_0_bias, ln_0_scale})
-        .LinksTo({ln_0_out, ln_0_mean, ln_0_variance});
+    ln_0_mean = pattern->NewNode(ln_0_mean_repr())
                    ->assert_is_op_output("layer_norm", "Mean")
                    ->assert_var_not_persistable();
+    ln_0_variance = pattern->NewNode(ln_0_variance_repr())
                        ->assert_is_op_output("layer_norm", "Variance")
                        ->assert_var_not_persistable();
   }

   // q: matmul + add + reshape + transpose
@@ -228,18 +229,22 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* q_reshape_out = pattern->NewNode(q_reshape_out_repr())
                             ->assert_is_op_output("reshape2", "Out")
                             ->assert_var_not_persistable();
-  auto* q_reshape_xshape = pattern->NewNode(q_reshape_xshape_repr())
-                               ->assert_is_op_output("reshape2", "XShape")
-                               ->assert_var_not_persistable();
   auto* q_transpose =
       pattern->NewNode(q_transpose_repr())->assert_is_op("transpose2");
   auto* q_transpose_out = pattern->NewNode(q_transpose_out_repr())
                               ->assert_is_op_output("transpose2", "Out")
-                              ->assert_is_op_input(matmul_type_1_, "X")
                               ->assert_var_not_persistable();
-  auto* q_transpose_xshape = pattern->NewNode(q_transpose_xshape_repr())
-                                 ->assert_is_op_output("transpose2", "XShape")
-                                 ->assert_var_not_persistable();
+  PDNode* q_scale = nullptr;
+  PDNode* q_scale_out = nullptr;
+  if (with_q_scale_) {
+    q_scale = pattern->NewNode(q_scale_repr())->assert_is_op("scale");
+    q_scale_out = pattern->NewNode(q_scale_out_repr())
+                      ->assert_is_op_output("scale", "Out")
+                      ->assert_is_op_input(matmul_type_1_, "X")
+                      ->assert_var_not_persistable();
+  } else {
+    q_transpose_out->assert_is_op_input(matmul_type_1_, "X");
+  }
   // k: matmul + add + reshape + transpose
   auto k_matmul_w = pattern->NewNode(k_matmul_w_repr())
@@ -262,18 +267,12 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* k_reshape_out = pattern->NewNode(k_reshape_out_repr())
                             ->assert_is_op_output("reshape2", "Out")
                             ->assert_var_not_persistable();
-  auto* k_reshape_xshape = pattern->NewNode(k_reshape_xshape_repr())
-                               ->assert_is_op_output("reshape2", "XShape")
-                               ->assert_var_not_persistable();
   auto* k_transpose =
       pattern->NewNode(k_transpose_repr())->assert_is_op("transpose2");
   auto* k_transpose_out = pattern->NewNode(k_transpose_out_repr())
                               ->assert_is_op_output("transpose2", "Out")
                               ->assert_is_op_input(matmul_type_1_, "Y")
                               ->assert_var_not_persistable();
-  auto* k_transpose_xshape = pattern->NewNode(k_transpose_xshape_repr())
-                                 ->assert_is_op_output("transpose2", "XShape")
-                                 ->assert_var_not_persistable();

   // qk: matmul + add + softmax
   auto* qk_matmul =
@@ -281,17 +280,17 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* qk_matmul_out = pattern->NewNode(qk_matmul_out_repr())
                             ->assert_is_op_output(matmul_type_1_, "Out")
                             ->assert_var_not_persistable();
+  PDNode* qk_add_mask = nullptr;
+  PDNode* qk_add = nullptr;
   PDNode* qk_add_out = nullptr;
   if (with_mask_) {
-    auto qk_add_mask = pattern->NewNode(qk_add_mask_repr())
-                           ->assert_is_op_input("elementwise_add", "Y")
-                           ->assert_var_not_persistable();
-    auto* qk_add =
-        pattern->NewNode(qk_add_repr())->assert_is_op("elementwise_add");
+    qk_add_mask = pattern->NewNode(qk_add_mask_repr())
                      ->assert_is_op_input("elementwise_add", "Y")
                      ->assert_var_not_persistable();
+    qk_add = pattern->NewNode(qk_add_repr())->assert_is_op("elementwise_add");
     qk_add_out = pattern->NewNode(qk_add_out_repr())
                      ->assert_is_op_output("elementwise_add", "Out")
                      ->assert_var_not_persistable();
-    qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out});
   }
   auto* qk_softmax =
       pattern->NewNode(qk_softmax_repr())->assert_is_op("softmax");
@@ -321,18 +320,12 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* v_reshape_out = pattern->NewNode(v_reshape_out_repr())
                             ->assert_is_op_output("reshape2", "Out")
                             ->assert_var_not_persistable();
-  auto* v_reshape_xshape = pattern->NewNode(v_reshape_xshape_repr())
-                               ->assert_is_op_output("reshape2", "XShape")
-                               ->assert_var_not_persistable();
   auto* v_transpose =
       pattern->NewNode(v_transpose_repr())->assert_is_op("transpose2");
   auto* v_transpose_out = pattern->NewNode(v_transpose_out_repr())
                               ->assert_is_op_output("transpose2", "Out")
                               ->assert_is_op_input(matmul_type_2_, "Y")
                               ->assert_var_not_persistable();
-  auto* v_transpose_xshape = pattern->NewNode(v_transpose_xshape_repr())
-                                 ->assert_is_op_output("transpose2", "XShape")
-                                 ->assert_var_not_persistable();

   // qkv
   auto* qkv_matmul_0 =
@@ -345,17 +338,11 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* qkv_transpose_out = pattern->NewNode(qkv_transpose_out_repr())
                                 ->assert_is_op_output("transpose2", "Out")
                                 ->assert_var_not_persistable();
-  auto* qkv_transpose_xshape = pattern->NewNode(qkv_transpose_xshape_repr())
-                                   ->assert_is_op_output("transpose2", "XShape")
-                                   ->assert_var_not_persistable();
   auto* qkv_reshape =
       pattern->NewNode(qkv_reshape_repr())->assert_is_op("reshape2");
   auto* qkv_reshape_out = pattern->NewNode(qkv_reshape_out_repr())
                               ->assert_is_op_output("reshape2", "Out")
                               ->assert_var_not_persistable();
-  auto* qkv_reshape_xshape = pattern->NewNode(qkv_reshape_xshape_repr())
-                                 ->assert_is_op_output("reshape2", "XShape")
-                                 ->assert_var_not_persistable();
   auto qkv_matmul_1_w = pattern->NewNode(qkv_matmul_1_w_repr())
                             ->assert_is_op_input(matmul_type_0_, "Y")
                             ->assert_is_persistable_var();
@@ -435,61 +422,70 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
   auto* qkv_add_4_out = pattern->NewNode(qkv_add_4_out_repr())
                             ->assert_is_op_output("elementwise_add", "Out")
                             ->assert_var_not_persistable();
+  PDNode* ln_2_bias = nullptr;
+  PDNode* ln_2_scale = nullptr;
+  PDNode* ln_2 = nullptr;
   PDNode* ln_2_out = nullptr;
+  PDNode* ln_2_mean = nullptr;
+  PDNode* ln_2_variance = nullptr;
   if (!norm_before_) {
-    auto* ln_2_bias = pattern->NewNode(ln_2_bias_repr())
-                          ->assert_is_op_input("layer_norm", "Bias")
-                          ->assert_is_persistable_var();
-    auto* ln_2_scale = pattern->NewNode(ln_2_scale_repr())
-                           ->assert_is_op_input("layer_norm", "Scale")
-                           ->assert_is_persistable_var();
-    auto* ln_2 = pattern->NewNode(ln_2_repr())->assert_is_op("layer_norm");
+    ln_2_bias = pattern->NewNode(ln_2_bias_repr())
                    ->assert_is_op_input("layer_norm", "Bias")
                    ->assert_is_persistable_var();
+    ln_2_scale = pattern->NewNode(ln_2_scale_repr())
                     ->assert_is_op_input("layer_norm", "Scale")
                     ->assert_is_persistable_var();
+    ln_2 = pattern->NewNode(ln_2_repr())->assert_is_op("layer_norm");
     ln_2_out = pattern->NewNode(ln_2_out_repr())
                    ->assert_is_op_output("layer_norm", "Y")
                    ->assert_var_not_persistable();
-    auto* ln_2_mean = pattern->NewNode(ln_2_mean_repr())
-                          ->assert_is_op_output("layer_norm", "Mean")
-                          ->assert_var_not_persistable();
-    auto* ln_2_variance = pattern->NewNode(ln_2_variance_repr())
-                              ->assert_is_op_output("layer_norm", "Variance")
-                              ->assert_var_not_persistable();
-    ln_2->LinksFrom({qkv_add_4_out, ln_2_bias, ln_2_scale})
-        .LinksTo({ln_2_out, ln_2_mean, ln_2_variance});
+    ln_2_mean = pattern->NewNode(ln_2_mean_repr())
                    ->assert_is_op_output("layer_norm", "Mean")
                    ->assert_var_not_persistable();
+    ln_2_variance = pattern->NewNode(ln_2_variance_repr())
                        ->assert_is_op_output("layer_norm", "Variance")
                        ->assert_var_not_persistable();
   }

   // link nodes
   PDNode* q_matmul_x = ln_0_x;
-  if (norm_before_) q_matmul_x = ln_0_out;
+  if (norm_before_) {
+    ln_0->LinksFrom({ln_0_x, ln_0_bias, ln_0_scale})
+        .LinksTo({ln_0_out, ln_0_mean, ln_0_variance});
+    q_matmul_x = ln_0_out;
+  }
   q_matmul->LinksFrom({q_matmul_x, q_matmul_w}).LinksTo({q_matmul_out});
   q_add->LinksFrom({q_matmul_out, q_add_bias}).LinksTo({q_add_out});
-  q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out, q_reshape_xshape});
-  q_transpose->LinksFrom({q_reshape_out})
-      .LinksTo({q_transpose_out, q_transpose_xshape});
+  q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out});
+  q_transpose->LinksFrom({q_reshape_out}).LinksTo({q_transpose_out});
+  PDNode* qk_matmul_x = q_transpose_out;
+  if (with_q_scale_) {
+    q_scale->LinksFrom({q_transpose_out}).LinksTo({q_scale_out});
+    qk_matmul_x = q_scale_out;
+  }
   k_matmul->LinksFrom({q_matmul_x, k_matmul_w}).LinksTo({k_matmul_out});
   k_add->LinksFrom({k_matmul_out, k_add_bias}).LinksTo({k_add_out});
-  k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out, k_reshape_xshape});
-  k_transpose->LinksFrom({k_reshape_out})
-      .LinksTo({k_transpose_out, k_transpose_xshape});
-  qk_matmul->LinksFrom({q_transpose_out, k_transpose_out})
-      .LinksTo({qk_matmul_out});
+  k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out});
+  k_transpose->LinksFrom({k_reshape_out}).LinksTo({k_transpose_out});
+  qk_matmul->LinksFrom({qk_matmul_x, k_transpose_out}).LinksTo({qk_matmul_out});
   PDNode* qk_softmax_x = qk_matmul_out;
-  if (with_mask_) qk_softmax_x = qk_add_out;
+  if (with_mask_) {
+    qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out});
+    qk_softmax_x = qk_add_out;
+  }
   qk_softmax->LinksFrom({qk_softmax_x}).LinksTo({qk_softmax_out});
   v_matmul->LinksFrom({q_matmul_x, v_matmul_w}).LinksTo({v_matmul_out});
   v_add->LinksFrom({v_matmul_out, v_add_bias}).LinksTo({v_add_out});
-  v_reshape->LinksFrom({v_add_out}).LinksTo({v_reshape_out, v_reshape_xshape});
-  v_transpose->LinksFrom({v_reshape_out})
-      .LinksTo({v_transpose_out, v_transpose_xshape});
+  v_reshape->LinksFrom({v_add_out}).LinksTo({v_reshape_out});
+  v_transpose->LinksFrom({v_reshape_out}).LinksTo({v_transpose_out});
   qkv_matmul_0->LinksFrom({qk_softmax_out, v_transpose_out})
       .LinksTo({qkv_matmul_0_out});
-  qkv_transpose->LinksFrom({qkv_matmul_0_out})
-      .LinksTo({qkv_transpose_out, qkv_transpose_xshape});
-  qkv_reshape->LinksFrom({qkv_transpose_out})
-      .LinksTo({qkv_reshape_out, qkv_reshape_xshape});
+  qkv_transpose->LinksFrom({qkv_matmul_0_out}).LinksTo({qkv_transpose_out});
+  qkv_reshape->LinksFrom({qkv_transpose_out}).LinksTo({qkv_reshape_out});
   qkv_matmul_1->LinksFrom({qkv_reshape_out, qkv_matmul_1_w})
       .LinksTo({qkv_matmul_1_out});
   qkv_add_0->LinksFrom({qkv_matmul_1_out, qkv_add_0_bias})
@@ -511,6 +507,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
         .LinksTo({qkv_add_4_out});
   } else {
     qkv_add_4->LinksFrom({qkv_add_3_out, ln_1_out}).LinksTo({qkv_add_4_out});
+    ln_2->LinksFrom({qkv_add_4_out, ln_2_bias, ln_2_scale})
+        .LinksTo({ln_2_out, ln_2_mean, ln_2_variance});
   }
 }
@@ -614,6 +612,7 @@ class MultiEncoderXPUFusePass : public FusePassBase {
                                 const std::string& matmul_type_1,
                                 const std::string& matmul_type_2,
                                 bool norm_before,
+                                bool with_q_scale,
                                 bool with_mask) const;

   bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
@@ -641,10 +640,11 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   Init(name_scope_, graph);
   std::vector<std::string> act_types{"gelu", "relu"};
-  std::vector<std::string> matmul_types_0{"mul", "matmul", "matmul_v2"};
-  std::vector<std::string> matmul_types_1{"matmul", "matmul_v2"};
-  std::vector<std::string> matmul_types_2{"matmul", "matmul_v2"};
+  std::vector<std::string> matmul_types_0{"matmul_v2", "matmul", "mul"};
+  std::vector<std::string> matmul_types_1{"matmul_v2", "matmul"};
+  std::vector<std::string> matmul_types_2{"matmul_v2", "matmul"};
   std::vector<bool> norm_befores{true, false};
+  std::vector<bool> with_q_scales{true, false};
   std::vector<bool> with_masks{true, false};
   int single_encoder_fused_counts = 0;
   int multi_encoder_fused_counts = 0;
@@ -653,17 +653,20 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
     for (auto matmul_type_1 : matmul_types_1) {
       for (auto matmul_type_2 : matmul_types_2) {
         for (auto norm_before : norm_befores) {
-          for (auto with_mask : with_masks) {
-            single_encoder_fused_counts +=
-                ApplySingleEncoderXPUFuse(graph,
-                                          act_type,
-                                          matmul_type_0,
-                                          matmul_type_1,
-                                          matmul_type_2,
-                                          norm_before,
-                                          with_mask);
-            while (ApplyMultiEncoderXPUFuse(graph)) {
-              multi_encoder_fused_counts++;
+          for (auto with_q_scale : with_q_scales) {
+            for (auto with_mask : with_masks) {
+              single_encoder_fused_counts +=
+                  ApplySingleEncoderXPUFuse(graph,
+                                            act_type,
+                                            matmul_type_0,
+                                            matmul_type_1,
+                                            matmul_type_2,
+                                            norm_before,
+                                            with_q_scale,
+                                            with_mask);
+              while (ApplyMultiEncoderXPUFuse(graph)) {
+                multi_encoder_fused_counts++;
+              }
             }
           }
         }
@@ -734,6 +737,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
     const std::string& matmul_type_1,
     const std::string& matmul_type_2,
     bool norm_before,
+    bool with_q_scale,
     bool with_mask) const {
   GraphPatternDetector gpd;
   patterns::SingleEncoderXPUPattern pattern(gpd.mutable_pattern(),
@@ -743,6 +747,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
                                             matmul_type_1,
                                             matmul_type_2,
                                             norm_before,
+                                            with_q_scale,
                                             with_mask);

   int found_subgraph_count = 0;
@@ -756,6 +761,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
     GET_IR_NODE(q_add);
     GET_IR_NODE(q_reshape);
     GET_IR_NODE(q_transpose);
+    GET_IR_NODE(q_scale);
     GET_IR_NODE(k_matmul);
     GET_IR_NODE(k_add);
     GET_IR_NODE(k_reshape);
@@ -790,34 +796,27 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
     GET_IR_NODE(q_add_bias);
     GET_IR_NODE(q_add_out);
     GET_IR_NODE(q_reshape_out);
-    GET_IR_NODE(q_reshape_xshape);
     GET_IR_NODE(q_transpose_out);
-    GET_IR_NODE(q_transpose_xshape);
+    GET_IR_NODE(q_scale_out);
     GET_IR_NODE(k_matmul_w);
     GET_IR_NODE(k_matmul_out);
     GET_IR_NODE(k_add_bias);
     GET_IR_NODE(k_add_out);
     GET_IR_NODE(k_reshape_out);
-    GET_IR_NODE(k_reshape_xshape);
     GET_IR_NODE(k_transpose_out);
-    GET_IR_NODE(k_transpose_xshape);
     GET_IR_NODE(v_matmul_w);
     GET_IR_NODE(v_matmul_out);
     GET_IR_NODE(v_add_bias);
     GET_IR_NODE(v_add_out);
     GET_IR_NODE(v_reshape_out);
-    GET_IR_NODE(v_reshape_xshape);
     GET_IR_NODE(v_transpose_out);
-    GET_IR_NODE(v_transpose_xshape);
     GET_IR_NODE(qk_matmul_out);
     GET_IR_NODE(qk_add_mask);
     GET_IR_NODE(qk_add_out);
     GET_IR_NODE(qk_softmax_out);
     GET_IR_NODE(qkv_matmul_0_out);
     GET_IR_NODE(qkv_transpose_out);
-    GET_IR_NODE(qkv_transpose_xshape);
     GET_IR_NODE(qkv_reshape_out);
-    GET_IR_NODE(qkv_reshape_xshape);
     GET_IR_NODE(qkv_matmul_1_w);
     GET_IR_NODE(qkv_matmul_1_out);
     GET_IR_NODE(qkv_add_0_bias);
@@ -1019,30 +1018,22 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
         q_matmul_out,
         q_add_out,
         q_reshape_out,
-        q_reshape_xshape,
         q_transpose_out,
-        q_transpose_xshape,
         k_matmul_w,
         k_matmul_out,
         k_add_out,
         k_reshape_out,
-        k_reshape_xshape,
         k_transpose_out,
-        k_transpose_xshape,
         v_matmul_w,
         v_matmul_out,
         v_add_out,
         v_reshape_out,
-        v_reshape_xshape,
         v_transpose_out,
-        v_transpose_xshape,
         qk_matmul_out,
         qk_softmax_out,
         qkv_matmul_0_out,
         qkv_transpose_out,
-        qkv_transpose_xshape,
         qkv_reshape_out,
-        qkv_reshape_xshape,
         qkv_matmul_1_out,
         qkv_add_0_out,
         qkv_add_1_out,
@@ -1065,6 +1056,10 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
       delete_nodes.insert(ln_2_mean);
       delete_nodes.insert(ln_2_variance);
     }
+    if (with_q_scale) {
+      delete_nodes.insert(q_scale);
+      delete_nodes.insert(q_scale_out);
+    }
     if (with_mask) {
       delete_nodes.insert(qk_add);
       delete_nodes.insert(qk_add_out);
......
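The remaining hunk touches the XPU pass list: identity_scale_op_clean_pass is scheduled before the encoder fuse passes. Judging by its name, it removes scale ops that are numerical no-ops, which would otherwise interpose between the ops the encoder pattern expects to be adjacent. A hypothetical example of such an op (this commit does not include that pass's code; attribute names scale, bias, bias_after_scale are those of Paddle's scale op):

// Hypothetical identity scale: Out = 1.0f * X + 0.0f, a pure pass-through
// that an exporter may leave behind, e.g. between transpose2 and qk matmul.
framework::OpDesc scale_op;
scale_op.SetType("scale");
scale_op.SetInput("X", {"q_transpose_out"});
scale_op.SetOutput("Out", {"q_scale_out"});
scale_op.SetAttr("scale", 1.0f);
scale_op.SetAttr("bias", 0.0f);
scale_op.SetAttr("bias_after_scale", true);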
@@ -517,6 +517,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
 XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       "delete_dropout_op_pass",
+      "identity_scale_op_clean_pass",
       "generate_sequence_xpu_fuse_pass",
       "multi_encoder_xpu_fuse_pass",
       "multi_encoder_xpu_slice_fuse_pass",
......