未验证 提交 8d4450db 编写于 作者: R RichardWooSJTU 提交者: GitHub

update fused_multi_transformer_encoder_pass support GPT new matmul API (#48953)

* fit paddle.matmul in fleetx.gpt
上级 41d27818
......@@ -478,7 +478,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
......@@ -496,7 +496,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign");
auto* concat_v_in_var = pattern
->NewNode(concat_v_in_repr())
......@@ -529,10 +529,16 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
assign_v->LinksFrom({concat_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
......@@ -554,7 +560,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
// QK path Linsk
matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
......@@ -799,7 +806,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
......@@ -817,7 +824,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign");
auto* concat_v_in_var = pattern
->NewNode(concat_v_in_repr())
......@@ -850,10 +857,16 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
assign_v->LinksFrom({concat_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
......@@ -875,7 +888,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
// QK path Linsk
matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
......@@ -2192,6 +2206,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
......@@ -2296,6 +2315,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
assign_v,
matmul_qk,
matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
......@@ -2382,6 +2403,23 @@ FusedMultiTransformerDecoderFuseQKVPass::
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
......@@ -2917,6 +2955,11 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
......@@ -3031,6 +3074,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
assign_v,
matmul_qk,
matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
......@@ -3124,6 +3169,23 @@ MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
......
......@@ -182,6 +182,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(scale_qk);
PATTERN_DECL_NODE(scale_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
......@@ -282,6 +284,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(scale_qk);
PATTERN_DECL_NODE(scale_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
......
......@@ -243,8 +243,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_v) concat -> concat_v
// (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
......@@ -298,10 +299,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
layers.assign(concat_v);
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true);
auto* matmul_qk = layers.matmul_v2(split_q, concat_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul
......@@ -361,11 +363,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 50,
num_nodes_after + 52,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 50,
num_nodes_before - 52,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
......@@ -396,8 +398,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_v) concat -> concat_v
// (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
......@@ -455,10 +458,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
layers.assign(concat_v);
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true);
auto* matmul_qk = layers.matmul_v2(split_q, concat_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul
......@@ -523,11 +527,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 58,
num_nodes_after + 60,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 58,
num_nodes_before - 60,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
......
......@@ -472,11 +472,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul", "Y")
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
......@@ -499,10 +499,17 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
......@@ -524,7 +531,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
// QK path Linsk
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
......@@ -769,11 +777,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul", "Y")
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
......@@ -796,10 +804,17 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
......@@ -821,7 +836,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
// QK path Linsk
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
......@@ -2637,6 +2653,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
......@@ -2739,6 +2760,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_v_out,
matmul_qk,
matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
......@@ -2826,6 +2849,23 @@ FusedMultiTransformerEncoderFuseQKVPass::
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
......@@ -3468,6 +3508,11 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
......@@ -3580,6 +3625,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_v_out,
matmul_qk,
matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
......@@ -3675,6 +3722,23 @@ MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
......
......@@ -168,6 +168,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(scale_qk);
PATTERN_DECL_NODE(scale_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
......@@ -263,6 +265,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(scale_qk);
PATTERN_DECL_NODE(scale_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
......
......@@ -236,9 +236,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// (transpose_0) split -> split_q, split_k,
// split_v (split_k) assign -> assign_k
// (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk) scale -> scale_qk
// (scale_qk, eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
......@@ -289,10 +289,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
layers.while_loop({split_k, split_v});
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true);
auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul
......@@ -352,11 +353,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 44,
num_nodes_after + 46,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 44,
num_nodes_before - 46,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
......@@ -385,8 +386,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// (transpose_0) split -> split_q, split_k,
// split_v (split_k) assign -> assign_k
// (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
......@@ -442,10 +444,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
layers.while_loop({split_k, split_v});
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true);
auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul
......@@ -510,11 +513,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 52,
num_nodes_after + 54,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 52,
num_nodes_before - 54,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册