未验证 提交 0f59d4e6 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for multihead_matmul_fuse_pass_v2,v3, test=develop (#33786)

上级 7f9b8f06
......@@ -422,13 +422,335 @@ PDNode* MultiHeadMatmulPattern::operator()() {
return transpose2_2_out_var;
}
static int BuildFusionV2(Graph* graph, const std::string& name_scope,
Scope* scope) {
PDNode* MultiHeadMatmulV3Pattern::operator()() {
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("matmul");
// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
decltype(mul0) eltadd0_out_var;
mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add");
eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var =
pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add");
eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("matmul");
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul");
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul");
decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
decltype(mul1) eltadd1_out_var;
mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add");
eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var =
pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul", "Y"); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul");
decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
decltype(mul2) eltadd2_out_var;
mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add");
eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var =
pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qkv
// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// compute q*k
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// compute q*k*v
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
return transpose2_2_out_var;
}
} // namespace patterns
void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = patterns::BuildFusion(graph, name_scope_);
AddStatis(fusion_count);
}
MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
AddOpCompat(OpCompat("mul"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(2)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
// in bias, shape is (B, S, N*H),
// in biasqk, shape is (B, H, S, S)
.IsTensor()
.End()
.AddInput("Y")
// in bias, shape is (N*H)
// in biasqk, shape is (B, H, S, S)
.IsTensor()
.End()
// in bias, shape is (B, S, N*H)
// in biasqk, shape is (B, H, S, S)
.AddOutput("Out")
.IsTensor()
.End()
// in bias, it equal to 2
// in biasqk, it equal to -1 or 0
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
// -->: (B, S, H, N) -> (B, H, S, N)
// <--: (B, H, S, N) -> (B, S, H, N)
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
// QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
// QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumEQ(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y") // QK(true) QKV(false)
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
}
int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
patterns::MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
multihead_pattern();
// Create New OpDesc
......@@ -580,6 +902,11 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "Op compat check in multihead_matmul_fuse_pass_v2 failed.";
return;
}
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
......@@ -714,197 +1041,141 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
return fusion_count;
}
PDNode* MultiHeadMatmulV3Pattern::operator()() {
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("matmul");
// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
decltype(mul0) eltadd0_out_var;
mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add");
eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var =
pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add");
eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("matmul");
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul");
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul");
decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
decltype(mul1) eltadd1_out_var;
mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add");
eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var =
pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul", "Y"); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul");
decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
decltype(mul2) eltadd2_out_var;
mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add");
eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var =
pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qkv
// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multiheadMatmul pass, The scope should not be null."));
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// compute q*k
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// compute q*k*v
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
int fusion_count = BuildFusionV2(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiheadMatmulPass, new bool(true));
}
AddStatis(fusion_count);
}
return transpose2_2_out_var;
MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() {
AddOpCompat(OpCompat("mul"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(2)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
// in bias, shape is (B, S, N*H),
// in biasqk, shape is (B, H, S, S)
.IsTensor()
.End()
.AddInput("Y")
// in bias, shape is (N*H)
// in biasqk, shape is (B, H, S, S)
.IsTensor()
.End()
// in bias, shape is (B, S, N*H)
// in biasqk, shape is (B, H, S, S)
.AddOutput("Out")
.IsTensor()
.End()
// in bias, it equal to 2
// in biasqk, it equal to -1 or 0
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
// -->: (B, S, H, N) -> (B, H, S, N)
// <--: (B, H, S, N) -> (B, S, H, N)
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
// QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
// QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>() // QK(anyvalue, will copy to new op) QKV(1.0)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y") // QK(true) QKV(false)
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
}
static int BuildFusionV3(Graph* graph, const std::string& name_scope,
Scope* scope) {
int MultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
patterns::MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
multihead_pattern();
// Create New OpDesc
......@@ -1155,30 +1426,6 @@ static int BuildFusionV3(Graph* graph, const std::string& name_scope,
return fusion_count;
}
} // namespace patterns
void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = patterns::BuildFusion(graph, name_scope_);
AddStatis(fusion_count);
}
void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multiheadMatmul pass, The scope should not be null."));
int fusion_count = patterns::BuildFusionV2(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiheadMatmulPass, new bool(true));
}
AddStatis(fusion_count);
}
void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
......@@ -1187,7 +1434,7 @@ void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
platform::errors::Fatal(
"During the multiheadMatmul pass, The scope should not be null."));
int fusion_count = patterns::BuildFusionV3(graph, name_scope_, scope);
int fusion_count = BuildFusionV3(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiheadMatmulPass, new bool(true));
}
......
......@@ -18,16 +18,6 @@
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
......@@ -158,22 +148,30 @@ class MultiHeadMatmulFusePass : public FusePassBase {
class MultiHeadMatmulV2FusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulV2FusePass() {}
MultiHeadMatmulV2FusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse_v2"};
private:
int BuildFusionV2(Graph* graph, const std::string& name_scope,
Scope* scope) const;
};
class MultiHeadMatmulV3FusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulV3FusePass() {}
MultiHeadMatmulV3FusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse_v3"};
private:
int BuildFusionV3(Graph* graph, const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
......
......@@ -64,7 +64,7 @@ TEST(MultiHeadMatmulFusePass, basic) {
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) mul -> mul_qkv
Layers layers;
auto* x = layers.data("x", {128, 768});
auto* x = layers.data("x", {1, 128, 768});
auto out = layers.layer_norm(x);
auto* layer_out = out[0];
......@@ -72,41 +72,41 @@ TEST(MultiHeadMatmulFusePass, basic) {
auto* weights_1 = layers.data("weights1", {768, 768}, true);
auto* weights_2 = layers.data("weights2", {768, 768}, true);
auto* mul_out_0 = layers.mul(layer_out, weights_0);
auto* mul_out_1 = layers.mul(layer_out, weights_1);
auto* mul_out_2 = layers.mul(layer_out, weights_2);
auto* mul_out_0 = layers.mul(layer_out, weights_0, nullptr, 2);
auto* mul_out_1 = layers.mul(layer_out, weights_1, nullptr, 2);
auto* mul_out_2 = layers.mul(layer_out, weights_2, nullptr, 2);
auto* b0 = layers.data("bias_0", {768}, true);
auto* b1 = layers.data("bias_1", {768}, true);
auto* b2 = layers.data("bias_2", {768}, true);
auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0);
auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1);
auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2);
auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2, nullptr, 2);
std::vector<int> shape = {128, 12, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape);
std::vector<int> shape = {1, 128, 12, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis);
auto* transpose_1 = layers.transpose2(reshape_1, axis);
auto* transpose_2 = layers.transpose2(reshape_2, axis);
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false);
auto* matmul_qk = layers.matmul(scale_0, transpose_1);
auto* matmul_qk = layers.matmul(scale_0, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {768}, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3});
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {128, 768});
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 768}, true);
auto* weights_l = layers.data("weightsl", {768, 768}, true);
layers.mul(reshape_qkv_out, weights_l);
layers.mul(reshape_qkv_out, weights_l, nullptr, 2);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
......
......@@ -293,13 +293,17 @@ struct Layers {
return outs;
}
VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr) {
VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
bool transpose_x = false, bool transpose_y = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("matmul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("transpose_X", transpose_x);
op->SetAttr("transpose_Y", transpose_y);
op->SetAttr("alpha", 1.0f);
return out;
}
......
......@@ -23,6 +23,10 @@ def {
}
}
extra {
attrs {
name: "head_number"
type: INT
}
attrs {
name: "Scale_out"
type: FLOAT
......
......@@ -10,12 +10,12 @@ def {
name: "axis"
type: INT
}
}
extra {
attrs {
name: "data_format"
type: STRING
}
}
extra {
attrs {
name: "op_role"
type: INT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册