Unverified commit 0f59d4e6, authored by 王明冬, committed by GitHub

add compat precondition for multihead_matmul_fuse_pass_v2,v3, test=develop (#33786)

Parent 7f9b8f06
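The change adds the op-compat precondition in two places: the pass constructor registers per-op input/output/attribute constraints through AddOpCompat, and the subgraph handler calls IsCompat before rewriting, skipping any matched subgraph that violates them. Below is a minimal sketch of that structure only; the pass name SampleCompatFusePass and the single softmax constraint are illustrative placeholders, while the real v2/v3 constraints appear in the diff that follows.

// Sketch only: assumes the PaddlePaddle fuse-pass framework
// (FusePassBase, OpCompat, GraphPatternDetector) used in the diff below.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class SampleCompatFusePass : public FusePassBase {  // hypothetical pass name
 public:
  SampleCompatFusePass() {
    // Declare what the fused kernel can accept; checked later by IsCompat().
    AddOpCompat(OpCompat("softmax"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddAttr("axis")
        .IsIntIn({-1, 3})  // only the layouts the fused kernel supports
        .End();
  }

 protected:
  void ApplyImpl(Graph* graph) const override {
    GraphPatternDetector gpd;
    // ... build the subgraph pattern on gpd.mutable_pattern() ...
    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
      if (!IsCompat(subgraph, g)) {
        LOG(WARNING) << "Op compat check failed; skipping this subgraph.";
        return;  // precondition not met, do not fuse
      }
      // ... replace the matched subgraph with the fused op ...
    };
    gpd(graph, handler);
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle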
...@@ -422,13 +422,335 @@ PDNode* MultiHeadMatmulPattern::operator()() {
return transpose2_2_out_var;
}

static int BuildFusionV2(Graph* graph, const std::string& name_scope,
Scope* scope) {

PDNode* MultiHeadMatmulV3Pattern::operator()() {
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("matmul");
// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
decltype(mul0) eltadd0_out_var;
mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add");
eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var =
pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add");
eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("matmul");
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul");
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul");
decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
decltype(mul1) eltadd1_out_var;
mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add");
eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var =
pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul", "Y"); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul");
decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
decltype(mul2) eltadd2_out_var;
mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add");
eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var =
pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qkv
// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// compute q*k
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// compute q*k*v
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
return transpose2_2_out_var;
}
} // namespace patterns
void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = patterns::BuildFusion(graph, name_scope_);
AddStatis(fusion_count);
}
MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
AddOpCompat(OpCompat("mul"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(2)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
// in bias, shape is (B, S, N*H),
// in biasqk, shape is (B, H, S, S)
.IsTensor()
.End()
.AddInput("Y")
// in bias, shape is (N*H)
// in biasqk, shape is (B, H, S, S)
.IsTensor()
.End()
// in bias, shape is (B, S, N*H)
// in biasqk, shape is (B, H, S, S)
.AddOutput("Out")
.IsTensor()
.End()
// in bias, it is equal to 2
// in biasqk, it is equal to -1 or 0
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
// -->: (B, S, H, N) -> (B, H, S, N)
// <--: (B, H, S, N) -> (B, S, H, N)
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copied to the new op, so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
// QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
// QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumEQ(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y") // QK(true) QKV(false)
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
}
int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();

// Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
patterns::MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
multihead_pattern();
// Create New OpDesc
...@@ -580,6 +902,11 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "Op compat check in multihead_matmul_fuse_pass_v2 failed.";
return;
}
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
...@@ -714,197 +1041,141 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
return fusion_count;
}
void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multiheadMatmul pass, The scope should not be null."));

PDNode* MultiHeadMatmulV3Pattern::operator()() {
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("matmul");

// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
decltype(mul0) eltadd0_out_var;
mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add");
eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var =
pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add");
eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("matmul");
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul");
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul");
decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
decltype(mul1) eltadd1_out_var;
mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add");
eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var =
pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul", "Y"); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul");
decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
decltype(mul2) eltadd2_out_var;
mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add");
eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var =
pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qkv
// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// compute q*k
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// compute q*k*v
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});

return transpose2_2_out_var;

int fusion_count = BuildFusionV2(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiheadMatmulPass, new bool(true));
}
AddStatis(fusion_count);
}

MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() {
AddOpCompat(OpCompat("mul"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(2)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
// in bias, shape is (B, S, N*H),
// in biasqk, shape is (B, H, S, S)
.IsTensor()
.End()
.AddInput("Y")
// in bias, shape is (N*H)
// in biasqk, shape is (B, H, S, S)
.IsTensor()
.End()
// in bias, shape is (B, S, N*H)
// in biasqk, shape is (B, H, S, S)
.AddOutput("Out")
.IsTensor()
.End()
// in bias, it is equal to 2
// in biasqk, it is equal to -1 or 0
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
// -->: (B, S, H, N) -> (B, H, S, N)
// <--: (B, H, S, N) -> (B, S, H, N)
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
// QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
// QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>() // QK(anyvalue, will copy to new op) QKV(1.0)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y") // QK(true) QKV(false)
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
}

static int BuildFusionV3(Graph* graph, const std::string& name_scope,
Scope* scope) {

int MultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();

// Create pattern.
MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
patterns::MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
multihead_pattern();
// Create New OpDesc
...@@ -1155,30 +1426,6 @@ static int BuildFusionV3(Graph* graph, const std::string& name_scope,
return fusion_count;
}
} // namespace patterns
void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = patterns::BuildFusion(graph, name_scope_);
AddStatis(fusion_count);
}
void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multiheadMatmul pass, The scope should not be null."));
int fusion_count = patterns::BuildFusionV2(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiheadMatmulPass, new bool(true));
}
AddStatis(fusion_count);
}
void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
...@@ -1187,7 +1434,7 @@ void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal(
"During the multiheadMatmul pass, The scope should not be null."));
int fusion_count = patterns::BuildFusionV3(graph, name_scope_, scope);
int fusion_count = BuildFusionV3(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiheadMatmulPass, new bool(true));
}
......
...@@ -18,16 +18,6 @@
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
...@@ -158,22 +148,30 @@ class MultiHeadMatmulFusePass : public FusePassBase {
class MultiHeadMatmulV2FusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulV2FusePass() {}
MultiHeadMatmulV2FusePass();

protected:
void ApplyImpl(Graph* graph) const;

const std::string name_scope_{"multihead_matmul_fuse_v2"};
private:
int BuildFusionV2(Graph* graph, const std::string& name_scope,
Scope* scope) const;
};

class MultiHeadMatmulV3FusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulV3FusePass() {}
MultiHeadMatmulV3FusePass();

protected:
void ApplyImpl(Graph* graph) const;

const std::string name_scope_{"multihead_matmul_fuse_v3"};
private:
int BuildFusionV3(Graph* graph, const std::string& name_scope,
Scope* scope) const;
};

} // namespace ir
......
...@@ -64,7 +64,7 @@ TEST(MultiHeadMatmulFusePass, basic) {
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) mul -> mul_qkv
Layers layers;
auto* x = layers.data("x", {128, 768});
auto* x = layers.data("x", {1, 128, 768});
auto out = layers.layer_norm(x);
auto* layer_out = out[0];
...@@ -72,41 +72,41 @@ TEST(MultiHeadMatmulFusePass, basic) {
auto* weights_1 = layers.data("weights1", {768, 768}, true);
auto* weights_2 = layers.data("weights2", {768, 768}, true);

auto* mul_out_0 = layers.mul(layer_out, weights_0);
auto* mul_out_1 = layers.mul(layer_out, weights_1);
auto* mul_out_2 = layers.mul(layer_out, weights_2);
auto* mul_out_0 = layers.mul(layer_out, weights_0, nullptr, 2);
auto* mul_out_1 = layers.mul(layer_out, weights_1, nullptr, 2);
auto* mul_out_2 = layers.mul(layer_out, weights_2, nullptr, 2);

auto* b0 = layers.data("bias_0", {768}, true);
auto* b1 = layers.data("bias_1", {768}, true);
auto* b2 = layers.data("bias_2", {768}, true);

auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0);
auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1);
auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2);
std::vector<int> shape = {128, 12, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape);
auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2, nullptr, 2);
std::vector<int> shape = {1, 128, 12, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);

std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis);
auto* transpose_1 = layers.transpose2(reshape_1, axis);
auto* transpose_2 = layers.transpose2(reshape_2, axis);
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);

auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false);
auto* matmul_qk = layers.matmul(scale_0, transpose_1);
auto* bqk = layers.data("biasqk", {768}, true);
auto* matmul_qk = layers.matmul(scale_0, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);

auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2);

auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3});
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {128, 768});
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 768}, true);

auto* weights_l = layers.data("weightsl", {768, 768}, true);
layers.mul(reshape_qkv_out, weights_l);
layers.mul(reshape_qkv_out, weights_l, nullptr, 2);

std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
......
...@@ -293,13 +293,17 @@ struct Layers {
return outs;
}

VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr) {
VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
bool transpose_x = false, bool transpose_y = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("matmul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("transpose_X", transpose_x);
op->SetAttr("transpose_Y", transpose_y);
op->SetAttr("alpha", 1.0f);
return out;
}
......
...@@ -23,6 +23,10 @@ def {
}
}
extra {
attrs {
name: "head_number"
type: INT
}
attrs {
name: "Scale_out"
type: FLOAT
......
...@@ -10,12 +10,12 @@ def {
name: "axis"
type: INT
}
}
extra {
attrs {
name: "data_format"
type: STRING
}
}
extra {
attrs {
name: "op_role"
type: INT
......