未验证 提交 8d4450db 编写于 作者: R RichardWooSJTU 提交者: GitHub

update fused_multi_transformer_encoder_pass support GPT new matmul API (#48953)

* fit paddle.matmul in fleetx.gpt
上级 41d27818
...@@ -478,7 +478,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -478,7 +478,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul", "X"); ->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
->AsIntermediate() ->AsIntermediate()
...@@ -496,7 +496,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -496,7 +496,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr()) auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
->assert_is_op_output("concat") ->assert_is_op_output("concat")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul") ->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign"); ->assert_is_op_input("assign");
auto* concat_v_in_var = pattern auto* concat_v_in_var = pattern
->NewNode(concat_v_in_repr()) ->NewNode(concat_v_in_repr())
...@@ -529,10 +529,16 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -529,10 +529,16 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
assign_v->LinksFrom({concat_v_out_var}); assign_v->LinksFrom({concat_v_out_var});
// QK path Nodes // QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var = auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk = auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
...@@ -554,7 +560,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -554,7 +560,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
// QK path Links // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var}) matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
.LinksTo({matmul_qk_out_var}); .LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
...@@ -799,7 +806,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -799,7 +806,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul", "X"); ->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
->AsIntermediate() ->AsIntermediate()
...@@ -817,7 +824,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -817,7 +824,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr()) auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
->assert_is_op_output("concat") ->assert_is_op_output("concat")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul") ->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign"); ->assert_is_op_input("assign");
auto* concat_v_in_var = pattern auto* concat_v_in_var = pattern
->NewNode(concat_v_in_repr()) ->NewNode(concat_v_in_repr())
...@@ -850,10 +857,16 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -850,10 +857,16 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
assign_v->LinksFrom({concat_v_out_var}); assign_v->LinksFrom({concat_v_out_var});
// QK path Nodes // QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var = auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk = auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
...@@ -875,7 +888,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -875,7 +888,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
// QK path Links // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var}) matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
.LinksTo({matmul_qk_out_var}); .LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
...@@ -2192,6 +2206,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2192,6 +2206,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
...@@ -2296,6 +2315,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2296,6 +2315,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
assign_v, assign_v,
matmul_qk, matmul_qk,
matmul_qk_out, matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk, eltadd_qk,
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
...@@ -2382,6 +2403,23 @@ FusedMultiTransformerDecoderFuseQKVPass:: ...@@ -2382,6 +2403,23 @@ FusedMultiTransformerDecoderFuseQKVPass::
.IsNumGT(0) .IsNumGT(0)
.End(); .End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2")) AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape should be (B, S, N*H) .AddInput("X") // the shape should be (B, S, N*H)
.IsTensor() .IsTensor()
...@@ -2917,6 +2955,11 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2917,6 +2955,11 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
...@@ -3031,6 +3074,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -3031,6 +3074,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
assign_v, assign_v,
matmul_qk, matmul_qk,
matmul_qk_out, matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk, eltadd_qk,
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
...@@ -3124,6 +3169,23 @@ MultiDevicesFusedMultiTransformerDecoderFuseQKVPass:: ...@@ -3124,6 +3169,23 @@ MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::
.IsNumGT(0) .IsNumGT(0)
.End(); .End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2")) AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape should be (B, S, N*H) .AddInput("X") // the shape should be (B, S, N*H)
.IsTensor() .IsTensor()
......
...@@ -182,6 +182,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { ...@@ -182,6 +182,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
// Q, K matmul // Q, K matmul
PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out); PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(scale_qk);
PATTERN_DECL_NODE(scale_qk_out);
PATTERN_DECL_NODE(eltadd_qk); PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b); PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
...@@ -282,6 +284,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern ...@@ -282,6 +284,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
// Q, K matmul // Q, K matmul
PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out); PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(scale_qk);
PATTERN_DECL_NODE(scale_qk_out);
PATTERN_DECL_NODE(eltadd_qk); PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b); PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
......
...@@ -243,8 +243,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -243,8 +243,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_v) concat -> concat_v // (split_v) concat -> concat_v
// (concat_k) assign -> assign_k // (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v // (concat_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
...@@ -298,10 +299,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -298,10 +299,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
layers.assign(concat_v); layers.assign(concat_v);
// MHA: QK matmul // MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true); auto* matmul_qk = layers.matmul_v2(split_q, concat_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul // MHA: QKV matmul
...@@ -361,11 +363,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -361,11 +363,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 50, num_nodes_after + 52,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, " "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 50, num_nodes_before - 52,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -396,8 +398,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -396,8 +398,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_v) concat -> concat_v // (split_v) concat -> concat_v
// (concat_k) assign -> assign_k // (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v // (concat_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
...@@ -455,10 +458,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -455,10 +458,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
layers.assign(concat_v); layers.assign(concat_v);
// MHA: QK matmul // MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true); auto* matmul_qk = layers.matmul_v2(split_q, concat_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul // MHA: QKV matmul
...@@ -523,11 +527,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -523,11 +527,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 58, num_nodes_after + 60,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, " "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 58, num_nodes_before - 60,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
......
...@@ -472,11 +472,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -472,11 +472,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul", "X"); ->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
->AsOutput() ->AsOutput()
->assert_is_op_input("matmul", "Y") ->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while"); ->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr()) auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
...@@ -499,10 +499,17 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -499,10 +499,17 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
while0->LinksFrom({split0_k_out_var, split0_v_out_var}); while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes // QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var = auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk = auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
...@@ -524,7 +531,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -524,7 +531,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
// QK path Links // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var}); .LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
...@@ -769,11 +777,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -769,11 +777,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul", "X"); ->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
->AsOutput() ->AsOutput()
->assert_is_op_input("matmul", "Y") ->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while"); ->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr()) auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split") ->assert_is_op_output("split")
...@@ -796,10 +804,17 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -796,10 +804,17 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
while0->LinksFrom({split0_k_out_var, split0_v_out_var}); while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes // QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var = auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk = auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
...@@ -821,7 +836,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -821,7 +836,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
// QK path Links // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var}); .LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
...@@ -2637,6 +2653,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2637,6 +2653,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
...@@ -2739,6 +2760,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2739,6 +2760,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_v_out, split0_v_out,
matmul_qk, matmul_qk,
matmul_qk_out, matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk, eltadd_qk,
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
...@@ -2826,6 +2849,23 @@ FusedMultiTransformerEncoderFuseQKVPass:: ...@@ -2826,6 +2849,23 @@ FusedMultiTransformerEncoderFuseQKVPass::
.IsNumGT(0) .IsNumGT(0)
.End(); .End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2")) AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape should be (B, S, N*H) .AddInput("X") // the shape should be (B, S, N*H)
.IsTensor() .IsTensor()
...@@ -3468,6 +3508,11 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3468,6 +3508,11 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
...@@ -3580,6 +3625,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3580,6 +3625,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_v_out, split0_v_out,
matmul_qk, matmul_qk,
matmul_qk_out, matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk, eltadd_qk,
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
...@@ -3675,6 +3722,23 @@ MultiDevicesFusedMultiTransformerEncoderFuseQKVPass:: ...@@ -3675,6 +3722,23 @@ MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::
.IsNumGT(0) .IsNumGT(0)
.End(); .End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2")) AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape should be (B, S, N*H) .AddInput("X") // the shape should be (B, S, N*H)
.IsTensor() .IsTensor()
......
...@@ -168,6 +168,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -168,6 +168,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
// Q, K matmul // Q, K matmul
PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out); PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(scale_qk);
PATTERN_DECL_NODE(scale_qk_out);
PATTERN_DECL_NODE(eltadd_qk); PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b); PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
...@@ -263,6 +265,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern ...@@ -263,6 +265,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
// Q, K matmul // Q, K matmul
PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out); PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(scale_qk);
PATTERN_DECL_NODE(scale_qk_out);
PATTERN_DECL_NODE(eltadd_qk); PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b); PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
......
...@@ -236,9 +236,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -236,9 +236,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// (transpose_0) split -> split_q, split_k, // (transpose_0) split -> split_q, split_k,
// split_v (split_k) assign -> assign_k // split_v (split_k) assign -> assign_k
// (split_v) assign -> assign_v // (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk) scale -> scale_qk
// (eltadd_qk) softmax -> softmax_qk // (scale_qk, bias_qk) elementwise_add -> eltadd_qk; (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
...@@ -289,10 +289,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -289,10 +289,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
layers.while_loop({split_k, split_v}); layers.while_loop({split_k, split_v});
// MHA: QK matmul // MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true); auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul // MHA: QKV matmul
...@@ -352,11 +353,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -352,11 +353,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 44, num_nodes_after + 46,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, " "After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 44, num_nodes_before - 46,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -385,8 +386,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -385,8 +386,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// (transpose_0) split -> split_q, split_k, // (transpose_0) split -> split_q, split_k,
// split_v (split_k) assign -> assign_k // split_v (split_k) assign -> assign_k
// (split_v) assign -> assign_v // (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
...@@ -442,10 +444,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -442,10 +444,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
layers.while_loop({split_k, split_v}); layers.while_loop({split_k, split_v});
// MHA: QK matmul // MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true); auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul // MHA: QKV matmul
...@@ -510,11 +513,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -510,11 +513,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 52, num_nodes_after + 54,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, " "After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 52, num_nodes_before - 54,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册