未验证 提交 29eec2dd 编写于 作者: L lzy 提交者: GitHub

add multi_devices_fused_multi_transformer_encoder_pass and cherry-pick from 48349 (#49383)

上级 a2d7e1d7
...@@ -31,6 +31,8 @@ namespace framework { ...@@ -31,6 +31,8 @@ namespace framework {
namespace ir { namespace ir {
namespace patterns { namespace patterns {
static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
PDNode* FusedMultiTransformerDecoderPattern::operator()() { PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr()); auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X"); input0->assert_is_op_input("layer_norm", "X");
...@@ -359,11 +361,11 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { ...@@ -359,11 +361,11 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("gelu"); ->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_op_output("gelu") ->assert_is_ops_output(FFN_ACTS)
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2"); ->assert_is_op_input("matmul_v2");
...@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { ...@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
.LinksTo({ffn_matmul0_out_var}); .LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var}); .LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var}); .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var}); .LinksTo({ffn_eltadd1_out_var});
...@@ -678,11 +680,11 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -678,11 +680,11 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("gelu"); ->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_op_output("gelu") ->assert_is_ops_output(FFN_ACTS)
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2"); ->assert_is_op_input("matmul_v2");
...@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
.LinksTo({ffn_matmul0_out_var}); .LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var}); .LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var}); .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var}); .LinksTo({ffn_eltadd1_out_var});
...@@ -1026,11 +1028,11 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -1026,11 +1028,11 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("gelu"); ->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_op_output("gelu") ->assert_is_ops_output(FFN_ACTS)
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2"); ->assert_is_op_input("matmul_v2");
...@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { ...@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
.LinksTo({ffn_matmul0_out_var}); .LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var}); .LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var}); .LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var}) ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var}); .LinksTo({ffn_c_allreduce_sum_out_var});
...@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ...@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) { Node* ffn_output) {
auto* matmul0_op = matmul0->Op(); auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op(); auto* matmul_linear_op = matmul_linear->Op();
...@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ...@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute // output dropout attribute
fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("is_test", true);
...@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ...@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern); ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_pattern); ffn_act, ffn_act, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern); ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern); ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
...@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ...@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_act,
ffn_output); ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm, std::unordered_set<const Node*> marked_nodes({layer_norm,
...@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ...@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_eltadd1, ffn_eltadd1,
ffn_eltadd0_out, ffn_eltadd0_out,
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_act,
ffn_gelu_out, ffn_act_out,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
...@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) { Node* ffn_output) {
auto* matmul0_op = matmul0->Op(); auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op(); auto* matmul_linear_op = matmul_linear->Op();
...@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute // output dropout attribute
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
...@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
...@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_act,
ffn_output); ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm, std::unordered_set<const Node*> marked_nodes({layer_norm,
...@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_eltadd1, ffn_eltadd1,
ffn_eltadd0_out, ffn_eltadd0_out,
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_act,
ffn_gelu_out, ffn_act_out,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
...@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) { Node* ffn_output) {
auto* matmul_linear_op = matmul_linear->Op(); auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op(); auto* ffn_matmul_1_op = ffn_matmul1->Op();
...@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute // output dropout attribute
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
...@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
...@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_act,
ffn_output); ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm, std::unordered_set<const Node*> marked_nodes({layer_norm,
...@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ...@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_eltadd1, ffn_eltadd1,
ffn_eltadd0_out, ffn_eltadd0_out,
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_act,
ffn_gelu_out, ffn_act_out,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
......
...@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase { ...@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out); PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu); PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_gelu_out); PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out); PATTERN_DECL_NODE(ffn_matmul1_out);
...@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { ...@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out); PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu); PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_gelu_out); PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out); PATTERN_DECL_NODE(ffn_matmul1_out);
...@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern ...@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out); PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu); PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_gelu_out); PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out); PATTERN_DECL_NODE(ffn_matmul1_out);
......
...@@ -25,44 +25,20 @@ namespace framework { ...@@ -25,44 +25,20 @@ namespace framework {
namespace ir { namespace ir {
namespace patterns { namespace patterns {
PDNode* FusedMultiTransformerEncoderPattern::operator()() { static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* layer_norm = auto* input0 = pattern->NewNode(input0_repr())
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X") ->assert_is_op_input("matmul_v2", "X")
->assert_is_op_input("elementwise_add", "X")
->assert_more([](Node* x) { ->assert_more([](Node* x) {
if (x->outputs.size() == 3) { if (x->outputs.size() == 4) {
return true; return true;
} else { } else {
return false; return false;
} }
}); });
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// Q path Nodes // Q path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr()) auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
...@@ -95,15 +71,20 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -95,15 +71,20 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2") ->assert_is_op_output("transpose2")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul", "X"); ->assert_is_op_input("scale");
auto* scale_q = pattern->NewNode(scale_q_repr())->assert_is_op("scale");
auto* scale_q_out_var = pattern->NewNode(scale_q_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
// Q path Links // Q path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var}) matmul0->LinksFrom({input0, matmul0_w_var}).LinksTo({matmul0_out_var});
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var}); .LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
scale_q->LinksFrom({transpose2_0_out_var}).LinksTo({scale_q_out_var});
// K path Nodes // K path Nodes
auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2"); auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2");
...@@ -137,20 +118,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -137,20 +118,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2") ->assert_is_op_output("transpose2")
->AsOutput() ->AsIntermediate()
->assert_is_op_input("matmul", "Y") ->assert_is_op_input("matmul_v2", "Y");
->assert_is_op_input("while")
->assert_more([](Node* x) {
if (x->outputs.size() == 2) {
return true;
} else {
return false;
}
});
// K path Links // K path Links
matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var}) matmul1->LinksFrom({input0, matmul1_w_var}).LinksTo({matmul1_out_var});
.LinksTo({matmul1_out_var});
eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var}) eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var})
.LinksTo({eltadd1_out_var}); .LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
...@@ -187,29 +159,21 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -187,29 +159,21 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2") ->assert_is_op_output("transpose2")
->AsOutput() ->AsIntermediate()
->assert_is_op_input("matmul_v2", "Y") ->assert_is_op_input("matmul_v2", "Y");
->assert_is_op_input("while")
->assert_more([](Node* x) {
if (x->outputs.size() == 2) {
return true;
} else {
return false;
}
});
// V path Links // V path Links
matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var}) matmul2->LinksFrom({input0, matmul2_w_var}).LinksTo({matmul2_out_var});
.LinksTo({matmul2_out_var});
eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var}) eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var})
.LinksTo({eltadd2_out_var}); .LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// QK path Nodes // QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var = auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk = auto* eltadd_qk =
...@@ -230,7 +194,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -230,7 +194,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
->assert_is_op_input("matmul_v2", "X"); ->assert_is_op_input("matmul_v2", "X");
// QK path Linsk // QK path Linsk
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) matmul_qk->LinksFrom({scale_q_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var}); .LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
...@@ -297,42 +261,41 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -297,42 +261,41 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
.LinksTo({attention_output}); .LinksTo({attention_output});
// while loop // post-LayerNorm
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while"); auto* layer_norm =
while0->LinksFrom({transpose2_1_out_var, transpose2_2_out_var}); pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale"); ->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var = auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias"); ->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var = auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("layer_norm", "Mean"); ->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var = auto* layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr()) pattern->NewNode(layer_norm_variance_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("layer_norm", "Variance"); ->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate() ->AsIntermediate()
->assert_is_op_output("layer_norm", "Y") ->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X"); ->assert_is_op_input("matmul_v2", "X")
->assert_is_op_input("elementwise_add", "X")
->assert_more([](Node* x) {
if (x->outputs.size() == 2) {
return true;
} else {
return false;
}
});
ffn_layer_norm layer_norm
->LinksFrom( ->LinksFrom({attention_output, layer_norm_bias_var, layer_norm_scale_var})
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) .LinksTo(
.LinksTo({ffn_layer_norm_out_var, {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 // Feed Forward fc1 -> gelu -> fc2
auto* ffn_matmul0 = auto* ffn_matmul0 =
...@@ -353,11 +316,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -353,11 +316,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("gelu"); ->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_op_output("gelu") ->assert_is_ops_output(FFN_ACTS)
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2"); ->assert_is_op_input("matmul_v2");
...@@ -385,22 +348,55 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -385,22 +348,55 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add"); pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr()) auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsOutput(); ->AsIntermediate()
->assert_is_op_input("layer_norm");
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var}) ffn_matmul0->LinksFrom({layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var}); .LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var}); .LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var}); .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var}); .LinksTo({ffn_eltadd1_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) ffn_eltadd_out->LinksFrom({layer_norm_out_var, ffn_eltadd1_out_var})
.LinksTo({ffn_output}); .LinksTo({ffn_output});
return ffn_output; // Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Y");
ffn_layer_norm
->LinksFrom(
{ffn_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
return ffn_layer_norm_out_var;
} }
PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...@@ -649,11 +645,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -649,11 +645,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("gelu"); ->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_op_output("gelu") ->assert_is_ops_output(FFN_ACTS)
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2"); ->assert_is_op_input("matmul_v2");
...@@ -687,8 +683,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -687,8 +683,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
.LinksTo({ffn_matmul0_out_var}); .LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var}); .LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var}); .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var}); .LinksTo({ffn_eltadd1_out_var});
...@@ -699,47 +695,41 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -699,47 +695,41 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
return ffn_output; return ffn_output;
} }
PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { PDNode* MultiDevicesFusedMultiTransformerEncoderPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr()); auto* input0 = pattern->NewNode(input0_repr())
input0->assert_is_op_input("layer_norm", "X"); ->assert_is_op_input("c_identity", "X")
->assert_is_op_input("elementwise_add", "X")
// pre-LayerNorm ->assert_more([](Node* x) {
auto* layer_norm = if (x->outputs.size() == 4) {
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); return true;
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) } else {
->AsInput() return false;
->assert_is_persistable_var() }
->assert_is_op_input("layer_norm", "Scale"); });
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// communication c_identity // communication c_identity
auto* c_identity = auto* c_identity0 =
pattern->NewNode(c_identity_repr())->assert_is_op("c_identity"); pattern->NewNode(c_identity0_repr())->assert_is_op("c_identity");
auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr()) auto* c_identity0_out_var = pattern->NewNode(c_identity0_out_repr())
->AsIntermediate() ->AsIntermediate()
->assert_is_op_output("c_identity", "Out") ->assert_is_op_output("c_identity", "Out")
->assert_is_op_input("matmul_v2", "X"); ->assert_is_op_input("matmul_v2", "X");
c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var}); auto* c_identity1 =
pattern->NewNode(c_identity1_repr())->assert_is_op("c_identity");
auto* c_identity1_out_var = pattern->NewNode(c_identity1_out_repr())
->AsIntermediate()
->assert_is_op_output("c_identity", "Out")
->assert_is_op_input("matmul_v2", "X");
auto* c_identity2 =
pattern->NewNode(c_identity2_repr())->assert_is_op("c_identity");
auto* c_identity2_out_var = pattern->NewNode(c_identity2_out_repr())
->AsIntermediate()
->assert_is_op_output("c_identity", "Out")
->assert_is_op_input("matmul_v2", "X");
c_identity0->LinksFrom({input0}).LinksTo({c_identity0_out_var});
c_identity1->LinksFrom({input0}).LinksTo({c_identity1_out_var});
c_identity2->LinksFrom({input0}).LinksTo({c_identity2_out_var});
// QKV fused path Nodes // Q path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr()) auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput() ->AsInput()
...@@ -771,75 +761,137 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -771,75 +761,137 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2") ->assert_is_op_output("transpose2")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("split", "X"); ->assert_is_op_input("scale");
auto* scale_q = pattern->NewNode(scale_q_repr())->assert_is_op("scale");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split"); auto* scale_q_out_var = pattern->NewNode(scale_q_out_repr())
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) ->assert_is_op_output("scale")
->assert_is_op_output("split")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); ->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
// QKV fused path Links // Q path Links
matmul0->LinksFrom({c_identity_out_var, matmul0_w_var}) matmul0->LinksFrom({c_identity0_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var}); .LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var}); .LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var}) scale_q->LinksFrom({transpose2_0_out_var}).LinksTo({scale_q_out_var});
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
// while loop
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale"); // K path Nodes
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr()) auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2");
->assert_is_op_output("scale") auto* matmul1_w_var = pattern->NewNode(matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul1_out_var = pattern->NewNode(matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("elementwise_add", "X"); ->assert_is_op_input("elementwise_add");
auto* eltadd_qk = auto* eltadd1 =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) auto* eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput() ->AsInput()
->assert_is_op_input("elementwise_add", "Y"); ->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
auto* eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("softmax"); ->assert_is_op_input("reshape2");
auto* softmax_qk = auto* reshape2_1 =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr())
->assert_is_op_output("softmax") ->assert_is_op_output("reshape2")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); ->assert_is_op_input("transpose2");
// QK path Linsk auto* transpose2_1 =
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
.LinksTo({matmul_qk_out_var}); auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var}); ->assert_is_op_output("transpose2")
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var}) ->AsIntermediate()
.LinksTo({eltadd_qk_out_var}); ->assert_is_op_input("matmul_v2", "Y");
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// K path Links
matmul1->LinksFrom({c_identity1_out_var, matmul1_w_var})
.LinksTo({matmul1_out_var});
eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var})
.LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// V path Nodes
auto* matmul2 = pattern->NewNode(matmul2_repr())->assert_is_op("matmul_v2");
auto* matmul2_w_var = pattern->NewNode(matmul2_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul2_out_var = pattern->NewNode(matmul2_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd2 =
pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
auto* eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "Y");
// V path Links
matmul2->LinksFrom({c_identity2_out_var, matmul2_w_var})
.LinksTo({matmul2_out_var});
eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var})
.LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// QK path Nodes
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
// QK path Linsk
matmul_qk->LinksFrom({scale_q_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// QKV path Nodes // QKV path Nodes
auto* matmul_qkv = auto* matmul_qkv =
...@@ -897,7 +949,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -897,7 +949,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
->AsIntermediate(); ->AsIntermediate();
// QKV path Links // QKV path Links
matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var}) matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var}); .LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var}) transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var}); .LinksTo({transpose2_qkv_out_var});
...@@ -912,38 +964,41 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -912,38 +964,41 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
.LinksTo({attention_output}); .LinksTo({attention_output});
// Feed Forward LayerNorm Nodes // post-LayerNorm
auto* ffn_layer_norm = auto* layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var = auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale"); ->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var = auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias"); ->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var = auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
pattern->NewNode(ffn_layer_norm_mean_repr()) ->AsOutput()
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean"); ->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var = auto* layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr()) pattern->NewNode(layer_norm_variance_repr())
->AsIntermediate() ->AsOutput()
->assert_is_op_output("layer_norm", "Variance"); ->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate() ->AsIntermediate()
->assert_is_op_output("layer_norm", "Y") ->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X"); ->assert_is_op_input("c_identity", "X")
->assert_is_op_input("elementwise_add", "X")
->assert_more([](Node* x) {
if (x->outputs.size() == 2) {
return true;
} else {
return false;
}
});
ffn_layer_norm layer_norm
->LinksFrom( ->LinksFrom({attention_output, layer_norm_bias_var, layer_norm_scale_var})
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) .LinksTo(
.LinksTo({ffn_layer_norm_out_var, {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// communication c_identity // communication c_identity
auto* ffn_c_identity = auto* ffn_c_identity =
...@@ -952,7 +1007,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -952,7 +1007,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
->assert_is_op_output("c_identity", "Out") ->assert_is_op_output("c_identity", "Out")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); ->assert_is_op_input("matmul_v2", "X");
ffn_c_identity->LinksFrom({ffn_layer_norm_out_var}) ffn_c_identity->LinksFrom({layer_norm_out_var})
.LinksTo({ffn_c_identity_out_var}); .LinksTo({ffn_c_identity_out_var});
// Feed Forward fc1 -> gelu -> fc2 // Feed Forward fc1 -> gelu -> fc2
...@@ -974,11 +1029,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -974,11 +1029,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("gelu"); ->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_op_output("gelu") ->assert_is_ops_output(FFN_ACTS)
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul_v2"); ->assert_is_op_input("matmul_v2");
...@@ -1015,297 +1070,1504 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -1015,297 +1070,1504 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add"); pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr()) auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsOutput(); ->AsIntermediate()
->assert_is_op_input("layer_norm");
ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var}) ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var}); .LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var}); .LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var}); .LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var}) ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var}); .LinksTo({ffn_c_allreduce_sum_out_var});
ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var}) ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var}); .LinksTo({ffn_eltadd1_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) ffn_eltadd_out->LinksFrom({layer_norm_out_var, ffn_eltadd1_out_var})
.LinksTo({ffn_output}); .LinksTo({ffn_output});
return ffn_output; // Feed Forward LayerNorm Nodes
} auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
} // namespace patterns auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Y");
template <typename T> ffn_layer_norm
inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor, ->LinksFrom(
phi::DenseTensor* wk_tensor, {ffn_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
phi::DenseTensor* wv_tensor, .LinksTo({ffn_layer_norm_out_var,
const int num_head, ffn_layer_norm_mean_var,
const int dim_head, ffn_layer_norm_variance_var});
const int dim_embed) {
auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed}); return ffn_layer_norm_out_var;
}
phi::DenseTensor tmp_combined_w_tensor; PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
tmp_combined_w_tensor.Resize(combined_w_dims); auto* input0 = pattern->NewNode(input0_repr());
auto* tmp_combined_w_data = input0->assert_is_op_input("layer_norm", "X");
tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());
std::vector<T*> w_vec = {wq_data, wk_data, wv_data}; // pre-LayerNorm
// Combine the three fc weights together. auto* layer_norm =
for (int i = 0; i < 3; i++) { pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
for (int j = 0; j < num_head; j++) { auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
for (int k = 0; k < dim_head; k++) { ->AsInput()
for (int l = 0; l < dim_embed; l++) { ->assert_is_persistable_var()
int out_idx = i * num_head * dim_head * dim_embed + ->assert_is_op_input("layer_norm", "Scale");
j * dim_head * dim_embed + k * dim_embed + l; auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
int in_idx = l * num_head * dim_head + j * dim_head + k; ->AsInput()
tmp_combined_w_data[out_idx] = w_vec[i][in_idx]; ->assert_is_persistable_var()
} ->assert_is_op_input("layer_norm", "Bias");
} auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
} ->AsOutput()
} ->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
wq_tensor->Resize(combined_w_dims); layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace()); .LinksTo(
memcpy( {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
}
template <typename T> // communication c_identity
inline void QKVBiasProcess(phi::DenseTensor* bq_tensor, auto* c_identity =
phi::DenseTensor* bk_tensor, pattern->NewNode(c_identity_repr())->assert_is_op("c_identity");
phi::DenseTensor* bv_tensor, auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr())
const int num_head, ->AsIntermediate()
const int dim_head, ->assert_is_op_output("c_identity", "Out")
const int dim_embed) { ->assert_is_op_input("matmul_v2", "X");
auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace()); c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var});
auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());
auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head}); // QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
phi::DenseTensor tmp_combined_bias_tensor; auto* eltadd0 =
tmp_combined_bias_tensor.Resize(combined_bias_dims); pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* tmp_combined_bias_data = auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace()); ->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
size_t bias_size = bq_tensor->numel(); auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
// QKV fused path Links
matmul0->LinksFrom({c_identity_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
// while loop
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
->assert_is_op_output("scale")
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
// QK path Linsk
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var});
scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* c_allreduce_sum =
pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum");
auto* c_allreduce_sum_out_var = pattern->NewNode(c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
c_allreduce_sum->LinksFrom({matmul_linear_out_var})
.LinksTo({c_allreduce_sum_out_var});
eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// communication c_identity
auto* ffn_c_identity =
pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity");
auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr())
->assert_is_op_output("c_identity", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
.LinksTo({ffn_c_identity_out_var});
// Feed Forward fc1 -> gelu -> fc2
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_ops_input(FFN_ACTS);
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_ops_output(FFN_ACTS)
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr())
->assert_is_op("c_allreduce_sum");
auto* ffn_c_allreduce_sum_out_var =
pattern->NewNode(ffn_c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var});
ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
} // namespace patterns
template <typename T>
inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor,
                              phi::DenseTensor* wk_tensor,
                              phi::DenseTensor* wv_tensor,
                              const int num_head,
                              const int dim_head,
                              const int dim_embed) {
  // Fuse the separate Q/K/V projection weights into one tensor of shape
  // [3, num_head, dim_head, dim_embed]. Each source weight is read as a
  // [dim_embed, num_head, dim_head] layout (see in_idx below) and
  // transposed on the fly. The fused result is written back into
  // wq_tensor, which is reused as the combined QKV weight.
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));

  std::vector<T*> src_weights = {
      wq_tensor->data<T>(), wk_tensor->data<T>(), wv_tensor->data<T>()};

  auto fused_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
  phi::DenseTensor staging;
  staging.Resize(fused_dims);
  cpu_ctx->Alloc<T>(&staging);
  auto* dst = staging.data<T>();

  // dst[i][j][k][l] = w_i[l][j][k]
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < num_head; j++) {
      for (int k = 0; k < dim_head; k++) {
        for (int l = 0; l < dim_embed; l++) {
          const int out_idx =
              ((i * num_head + j) * dim_head + k) * dim_embed + l;
          const int in_idx = (l * num_head + j) * dim_head + k;
          dst[out_idx] = src_weights[i][in_idx];
        }
      }
    }
  }

  // Reuse wq_tensor to hold the fused weight.
  wq_tensor->Resize(fused_dims);
  cpu_ctx->Alloc<T>(wq_tensor);
  memcpy(wq_tensor->data<T>(), dst, sizeof(T) * wq_tensor->numel());
}
template <typename T>
inline void QKVBiasProcess(phi::DenseTensor* bq_tensor,
                           phi::DenseTensor* bk_tensor,
                           phi::DenseTensor* bv_tensor,
                           const int num_head,
                           const int dim_head,
                           const int dim_embed) {
  // Concatenate the separate Q/K/V biases into one tensor of shape
  // [3, num_head, dim_head]; the result is written back into bq_tensor,
  // which is reused as the combined QKV bias (mirrors QKVWeightsProcess).
  // NOTE: the source lines here were corrupted by a merged diff
  // (duplicated statements and both the old mutable_data and new
  // dev_ctx->Alloc variants fused on one line); this is the clean,
  // de-duplicated version using dev_ctx->Alloc, consistent with the
  // other helpers in this file.
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  auto* bq_data = bq_tensor->data<T>();
  auto* bk_data = bk_tensor->data<T>();
  auto* bv_data = bv_tensor->data<T>();

  auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head});

  phi::DenseTensor tmp_combined_bias_tensor;
  tmp_combined_bias_tensor.Resize(combined_bias_dims);
  dev_ctx->Alloc<T>(&tmp_combined_bias_tensor);
  auto* tmp_combined_bias_data = tmp_combined_bias_tensor.data<T>();

  // Each bias holds num_head * dim_head elements; lay them out as
  // [q | k | v] in the staging buffer.
  size_t bias_size = bq_tensor->numel();
  memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
  memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
  memcpy(
      tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);

  // Reuse bq_tensor to hold the combined bias.
  bq_tensor->Resize(combined_bias_dims);
  dev_ctx->Alloc<T>(bq_tensor);
  auto* new_combined_bias_data = bq_tensor->data<T>();
  memcpy(new_combined_bias_data,
         tmp_combined_bias_data,
         sizeof(T) * bq_tensor->numel());
}
// Dispatches QKV weight/bias fusion on the runtime dtype of the tensors.
// Weights support fp32/fp16/int8; biases support fp32/fp16 only.
inline void QKVWeightsBiasProcess(phi::DenseTensor* wq_tensor,
                                  phi::DenseTensor* wk_tensor,
                                  phi::DenseTensor* wv_tensor,
                                  phi::DenseTensor* bq_tensor,
                                  phi::DenseTensor* bk_tensor,
                                  phi::DenseTensor* bv_tensor,
                                  const int num_head,
                                  const int dim_head,
                                  const int dim_embed) {
  const auto w_dtype = wq_tensor->dtype();
  if (w_dtype == paddle::experimental::DataType::FLOAT16) {
    QKVWeightsProcess<platform::float16>(
        wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
  } else if (w_dtype == paddle::experimental::DataType::FLOAT32) {
    QKVWeightsProcess<float>(
        wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
  } else if (w_dtype == paddle::experimental::DataType::INT8) {
    QKVWeightsProcess<int8_t>(
        wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "fused_multi_transformer not supported weight dtype. "
        "we now only support fp32/fp16/int8."));
  }

  const auto b_dtype = bq_tensor->dtype();
  if (b_dtype == paddle::experimental::DataType::FLOAT16) {
    QKVBiasProcess<platform::float16>(
        bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
  } else if (b_dtype == paddle::experimental::DataType::FLOAT32) {
    QKVBiasProcess<float>(
        bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "fused_multi_transformer not supported bias dtype. "
        "we now only support fp32/fp16."));
  }
}
template <typename T>
inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
                                     const int num_head,
                                     const int dim_head,
                                     const int dim_embed) {
  // Transpose an already-fused QKV matmul weight into the layout expected
  // by fused_multi_transformer: [3, num_head, dim_head, dim_embed]. The
  // source is read as [dim_embed, num_head, 3, dim_head] (see in_idx).
  // The result is written back into qkv_w_tensor in place.
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  const T* src = qkv_w_tensor->data<T>();

  auto fused_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
  phi::DenseTensor staging;
  staging.Resize(fused_dims);
  cpu_ctx->Alloc<T>(&staging);
  T* dst = staging.data<T>();

  // dst[i][j][k][l] = src[l][j][i][k]
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < num_head; j++) {
      for (int k = 0; k < dim_head; k++) {
        for (int l = 0; l < dim_embed; l++) {
          const int out_idx =
              ((i * num_head + j) * dim_head + k) * dim_embed + l;
          const int in_idx = ((l * num_head + j) * 3 + i) * dim_head + k;
          dst[out_idx] = src[in_idx];
        }
      }
    }
  }

  qkv_w_tensor->Resize(fused_dims);
  cpu_ctx->Alloc<T>(qkv_w_tensor);
  memcpy(qkv_w_tensor->data<T>(), dst, sizeof(T) * qkv_w_tensor->numel());
}
template <typename T>
inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor,
                                  const int num_head,
                                  const int dim_head,
                                  const int dim_embed) {
  // Transpose an already-fused QKV elementwise_add bias from
  // [num_head, 3, dim_head] (see in_idx) to [3, num_head, dim_head],
  // writing the result back into qkv_b_tensor in place.
  // NOTE: dim_embed is unused here; it is kept for signature symmetry
  // with QKVWeightsProcessFuseQKV.
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  const T* src = qkv_b_tensor->data<T>();

  auto fused_dims = phi::make_ddim({3, num_head, dim_head});
  phi::DenseTensor staging;
  staging.Resize(fused_dims);
  cpu_ctx->Alloc<T>(&staging);
  T* dst = staging.data<T>();

  // dst[i][j][k] = src[j][i][k]
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < num_head; j++) {
      for (int k = 0; k < dim_head; k++) {
        dst[(i * num_head + j) * dim_head + k] =
            src[(j * 3 + i) * dim_head + k];
      }
    }
  }

  qkv_b_tensor->Resize(fused_dims);
  cpu_ctx->Alloc<T>(qkv_b_tensor);
  memcpy(qkv_b_tensor->data<T>(), dst, sizeof(T) * qkv_b_tensor->numel());
}
// Dispatches fused-QKV weight/bias transposition on the runtime dtype.
// Weights support fp32/fp16/int8; biases support fp32/fp16 only.
inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
                                         phi::DenseTensor* qkv_b_tensor,
                                         const int num_head,
                                         const int dim_head,
                                         const int dim_embed) {
  const auto w_dtype = qkv_w_tensor->dtype();
  if (w_dtype == paddle::experimental::DataType::FLOAT16) {
    QKVWeightsProcessFuseQKV<platform::float16>(
        qkv_w_tensor, num_head, dim_head, dim_embed);
  } else if (w_dtype == paddle::experimental::DataType::FLOAT32) {
    QKVWeightsProcessFuseQKV<float>(
        qkv_w_tensor, num_head, dim_head, dim_embed);
  } else if (w_dtype == paddle::experimental::DataType::INT8) {
    QKVWeightsProcessFuseQKV<int8_t>(
        qkv_w_tensor, num_head, dim_head, dim_embed);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "fused_multi_transformer not supported weight dtype. "
        "we now only support fp32/fp16/int8."));
  }

  const auto b_dtype = qkv_b_tensor->dtype();
  if (b_dtype == paddle::experimental::DataType::FLOAT16) {
    QKVBiasProcessFuseQKV<platform::float16>(
        qkv_b_tensor, num_head, dim_head, dim_embed);
  } else if (b_dtype == paddle::experimental::DataType::FLOAT32) {
    QKVBiasProcessFuseQKV<float>(qkv_b_tensor, num_head, dim_head, dim_embed);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "fused_multi_transformer not supported bias dtype. "
        "we now only support fp32/fp16."));
  }
}
// Just use for fused_multi_transformer_int8
inline void TransposeWeights(phi::DenseTensor* weight_tensor) {
  // Transpose an int8 [m, n] weight matrix into [n, m], writing the
  // result back into weight_tensor through a staging buffer.
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  const int rows = weight_tensor->dims()[0];
  const int cols = weight_tensor->dims()[1];

  phi::DenseTensor staging;
  staging.Resize({cols, rows});
  cpu_ctx->Alloc<int8_t>(&staging);
  auto* dst = staging.data<int8_t>();
  const auto* src = weight_tensor->data<int8_t>();

  // dst[c][r] = src[r][c]
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      dst[c * rows + r] = src[r * cols + c];
    }
  }

  weight_tensor->Resize({cols, rows});
  cpu_ctx->Alloc<int8_t>(weight_tensor);
  memcpy(weight_tensor->data<int8_t>(), dst, sizeof(int8_t) * rows * cols);
}
// Creates and returns a new persistable FP32 variable node named `name`
// in `graph`.
inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) {
  VarDesc desc(name);
  desc.SetDataType(framework::proto::VarType::FP32);
  desc.SetPersistable(true);
  return graph->CreateVarNode(&desc);
}
int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) {
VLOG(3) << "FusedMultiTransformerEncoderPass with int8";
} else {
VLOG(3) << "FusedMultiTransformerEncoderPass with fp";
}
// Create pattern.
patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern(
pattern, name_scope);
fused_multi_transformer_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0,
Node* matmul0_w,
Node* matmul1_w,
Node* matmul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* transpose2_1_out,
Node* transpose2_2_out,
Node* eltadd_qk_b,
Node* reshape2_0,
Node* matmul_linear,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0,
Node* ffn_matmul0_w,
Node* ffn_matmul1,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_layer_norm_out) {
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_0_op = ffn_matmul0->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
// 2. each transformer layer contains 2 LayerNorm layer
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
auto* wq_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>();
auto* wk_tensor =
scope->FindVar(matmul1_w->Name())->GetMutable<phi::DenseTensor>();
auto* wv_tensor =
scope->FindVar(matmul2_w->Name())->GetMutable<phi::DenseTensor>();
auto* bq_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
auto* bk_tensor =
scope->FindVar(eltadd1_b->Name())->GetMutable<phi::DenseTensor>();
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
// NOTE(minghaoBD): to make it compatible with strucutured pruning on
// num_head dimension:
// 1. get dim_head from reshape.shape[3], dim_embed from
// layer_norm_bias.shape[0]
// 2. calculate num_head according to wq_tensor.shape[1] and dim_head
auto reshape_desc = reshape2_0->Op();
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3);
auto* layer_norm_bias_tensor =
scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
int dim_embed = layer_norm_bias_tensor->dims()[0];
int num_head = wq_tensor->dims()[1] / dim_head;
QKVWeightsBiasProcess(wq_tensor,
wk_tensor,
wv_tensor,
bq_tensor,
bk_tensor,
bv_tensor,
num_head,
dim_head,
dim_embed);
if (enable_int8) {
auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name())
->GetMutable<phi::DenseTensor>();
auto* ffn0_w_tensor =
scope->FindVar(ffn_matmul0_w->Name())->GetMutable<phi::DenseTensor>();
auto* ffn1_w_tensor =
scope->FindVar(ffn_matmul1_w->Name())->GetMutable<phi::DenseTensor>();
TransposeWeights(out_linear_w_tensor);
TransposeWeights(ffn0_w_tensor);
TransposeWeights(ffn1_w_tensor);
}
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* combined_w_desc = matmul0_w->Var();
combined_w_desc->SetShape({3, num_head, dim_head, dim_embed});
combined_w_desc->SetPersistable(true);
auto* combined_bias_desc = eltadd0_b->Var();
combined_bias_desc->SetShape({3, num_head, dim_head});
combined_bias_desc->SetPersistable(true);
scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()});
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType(enable_int8
? "fused_multi_transformer_int8"
: "fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV input
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType(
framework::TransToProtoVarType(bq_tensor->dtype()));
cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
OpDesc fill_const_op_desc(layer_norm->Op()->Block());
fill_const_op_desc.SetType("fill_constant_batch_size_like");
fill_const_op_desc.SetInput("Input", {input0->Name()});
fill_const_op_desc.SetOutput("Out", {cache_kv->Name()});
std::vector<int> shape = {2, -1, num_head, 1024, dim_head};
fill_const_op_desc.SetAttr("shape", shape);
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr(
"dtype",
static_cast<int>(framework::TransToProtoVarType(bq_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out",
{ffn_layer_norm_out->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", false);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
fused_multi_transformer_op_desc.SetAttr("act_method",
{ffn_act->Op()->Type()});
// Quantization attribute/Input
if (enable_int8) {
auto* dev_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
// Set input scale
std::string qkv_input_name = matmul0_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST(
float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name));
std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
auto out_linear_in_scale = PADDLE_GET_CONST(
float,
matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0];
auto ffn0_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name));
std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
auto ffn1_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
// Calc outscale and Set them
auto qkv_weight_scale =
PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale"));
auto out_weight_scale =
PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale"));
auto ffn0_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale"));
auto ffn1_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale"));
auto qkv_out_scales = std::vector<float>(
3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f));
auto out_out_scales = std::vector<float>(
dim_embed,
(out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f));
auto ffn0_out_scales = std::vector<float>(
4 * dim_embed,
(ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f));
auto ffn1_out_scales = std::vector<float>(
dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f));
// Inverse input scale
qkv_in_scale = 1.0f / qkv_in_scale;
out_linear_in_scale = 1.0f / out_linear_in_scale;
ffn0_in_scale = 1.0f / ffn0_in_scale;
ffn1_in_scale = 1.0f / ffn1_in_scale;
fused_multi_transformer_op_desc.SetAttr("qkv_in_scale",
std::vector<float>{qkv_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"out_linear_in_scale", std::vector<float>{out_linear_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn1_in_scale", std::vector<float>{ffn0_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn2_in_scale", std::vector<float>{ffn1_in_scale});
auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale");
auto out_out_scale_var =
scope->Var(matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_var =
scope->Var(ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_var =
scope->Var(ffn_matmul1_w->Name() + "_out_scale");
auto* qkv_out_scale_tensor =
qkv_out_scale_var->GetMutable<phi::DenseTensor>();
qkv_out_scale_tensor->Resize({3 * dim_embed});
dev_ctx->Alloc<float>(qkv_out_scale_tensor);
auto qkv_out_scale_data = qkv_out_scale_tensor->data<float>();
memcpy(qkv_out_scale_data,
qkv_out_scales.data(),
qkv_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"});
auto* out_out_scale_tensor =
out_out_scale_var->GetMutable<phi::DenseTensor>();
out_out_scale_tensor->Resize({dim_embed});
dev_ctx->Alloc<float>(out_out_scale_tensor);
auto out_out_scale_data = out_out_scale_tensor->data<float>();
memcpy(out_out_scale_data,
out_out_scales.data(),
out_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
auto* ffn0_out_scale_tensor =
ffn0_out_scale_var->GetMutable<phi::DenseTensor>();
ffn0_out_scale_tensor->Resize({4 * dim_embed});
dev_ctx->Alloc<float>(ffn0_out_scale_tensor);
auto ffn0_out_scale_data = ffn0_out_scale_tensor->data<float>();
memcpy(ffn0_out_scale_data,
ffn0_out_scales.data(),
ffn0_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
auto* ffn1_out_scale_tensor =
ffn1_out_scale_var->GetMutable<phi::DenseTensor>();
ffn1_out_scale_tensor->Resize({dim_embed});
dev_ctx->Alloc<float>(ffn1_out_scale_tensor);
auto ffn1_out_scale_data = ffn1_out_scale_tensor->data<float>();
memcpy(ffn1_out_scale_data,
ffn1_out_scales.data(),
ffn1_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
}
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
if (enable_int8) {
auto qkv_out_scale_node =
CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
auto out_out_scale_node = CreatePersistableVarNode(
graph, matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
}
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
IR_NODE_LINK_TO(input0, fill_const_op);
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_layer_norm_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_encoder pass in "
"op compat failed.";
return;
}
VLOG(4) << "handle MultiTransformer encoder fuse";
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1, matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_out, matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_w, matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1, reshape2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1, transpose2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_q, scale_q, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
scale_q_out, scale_q_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2, matmul2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_out, matmul2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_w, matmul2_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2, reshape2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2, transpose2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attention_output, attention_output, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_act, ffn_act, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_pattern)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1, eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_b, eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_out, eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2, eltadd2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_b, eltadd2_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_out, eltadd2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0,
matmul0_w,
matmul1_w,
matmul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
transpose2_1_out,
transpose2_2_out,
eltadd_qk_b,
reshape2_0,
matmul_linear,
matmul_linear_w,
eltadd_linear_b,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0,
ffn_matmul0_w,
ffn_matmul1,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_act,
ffn_layer_norm_out);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul1,
matmul2,
matmul0_out,
matmul1_out,
matmul2_out,
eltadd0,
eltadd1,
eltadd2,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
scale_q,
scale_q_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_out,
eltadd_linear,
eltadd_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_act,
ffn_act_out,
ffn_output,
ffn_eltadd_out});
inline void QKVWeightsBiasProcess(phi::DenseTensor* wq_tensor, // Remove unneeded nodes.
phi::DenseTensor* wk_tensor, GraphSafeRemoveNodes(graph, marked_nodes);
phi::DenseTensor* wv_tensor, ++fusion_count;
phi::DenseTensor* bq_tensor, };
phi::DenseTensor* bk_tensor, gpd(graph, handler);
phi::DenseTensor* bv_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
switch (wq_tensor->dtype()) {
case paddle::experimental::DataType::FLOAT16:
QKVWeightsProcess<platform::float16>(
wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::FLOAT32:
QKVWeightsProcess<float>(
wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::INT8:
QKVWeightsProcess<int8_t>(
wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32/fp16/int8."));
break;
}
switch (bq_tensor->dtype()) {
case paddle::experimental::DataType::FLOAT16:
QKVBiasProcess<platform::float16>(
bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::FLOAT32:
QKVBiasProcess<float>(
bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported bias dtype. "
"we now only support fp32/fp16."));
break;
}
}
template <typename T> return fusion_count;
inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, }
const int num_head,
const int dim_head,
const int dim_embed) {
auto* qkv_w_data = qkv_w_tensor->data<T>();
auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
phi::DenseTensor tmp_transpose_w_tensor; void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
tmp_transpose_w_tensor.Resize(transpose_w_dims); FusePassBase::Init(name_scope_, graph);
auto* tmp_transpose_w_data = auto* scope = param_scope();
tmp_transpose_w_tensor.mutable_data<T>(platform::CPUPlace()); PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));
// transpose qkv matmul Y to QKVWeights int fusion_count = BuildFusion(graph, name_scope_, scope);
for (int i = 0; i < 3; i++) { if (fusion_count > 0) {
for (int j = 0; j < num_head; j++) { graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
for (int k = 0; k < dim_head; k++) { graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count));
for (int l = 0; l < dim_embed; l++) {
int out_idx = i * num_head * dim_head * dim_embed +
j * dim_head * dim_embed + k * dim_embed + l;
int in_idx =
l * num_head * 3 * dim_head + j * 3 * dim_head + i * dim_head + k;
tmp_transpose_w_data[out_idx] = qkv_w_data[in_idx];
}
}
}
} }
AddStatis(fusion_count);
qkv_w_tensor->Resize(transpose_w_dims);
auto* new_transpose_w_data =
qkv_w_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_transpose_w_data,
tmp_transpose_w_data,
sizeof(T) * qkv_w_tensor->numel());
} }
template <typename T> FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor, AddOpCompat(OpCompat("layer_norm"))
const int num_head, .AddInput("X")
const int dim_head, .IsTensor()
const int dim_embed) { .End()
auto* qkv_b_data = qkv_b_tensor->data<T>(); .AddInput("Scale")
auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head}); .IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
phi::DenseTensor tmp_transpose_b_tensor; AddOpCompat(OpCompat("matmul_v2"))
tmp_transpose_b_tensor.Resize(transpose_b_dims); .AddInput("X") // the shape shoule be (B, S, N*H)
auto* tmp_transpose_b_data = .IsTensor()
tmp_transpose_b_tensor.mutable_data<T>(platform::CPUPlace()); .End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
// transpose qkv elemenwise_add Y to QKVBias AddOpCompat(OpCompat("elementwise_add"))
for (int i = 0; i < 3; i++) { .AddInput("X")
for (int j = 0; j < num_head; j++) { .IsTensor()
for (int k = 0; k < dim_head; k++) { .End()
int out_idx = i * num_head * dim_head + j * dim_head + k; .AddInput("Y")
int in_idx = j * 3 * dim_head + i * dim_head + k; .IsTensor()
tmp_transpose_b_data[out_idx] = qkv_b_data[in_idx]; .End()
} .AddOutput("Out")
} .IsTensor()
} .End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
qkv_b_tensor->Resize({3, num_head, dim_head}); AddOpCompat(OpCompat("scale"))
auto* new_transpose_b_data = .AddInput("X")
qkv_b_tensor->mutable_data<T>(platform::CPUPlace()); .IsTensor()
memcpy(new_transpose_b_data, .End()
tmp_transpose_b_data, .AddOutput("Out")
sizeof(T) * qkv_b_tensor->numel()); .IsTensor()
} .End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, AddOpCompat(OpCompat("softmax"))
phi::DenseTensor* qkv_b_tensor, .AddInput("X")
const int num_head, .IsTensor()
const int dim_head, .End()
const int dim_embed) { .AddOutput("Out")
switch (qkv_w_tensor->dtype()) { .IsTensor()
case paddle::experimental::DataType::FLOAT16: .End()
QKVWeightsProcessFuseQKV<platform::float16>( .AddAttr("axis")
qkv_w_tensor, num_head, dim_head, dim_embed); .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
break; .End();
case paddle::experimental::DataType::FLOAT32:
QKVWeightsProcessFuseQKV<float>(
qkv_w_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::INT8:
QKVWeightsProcessFuseQKV<int8_t>(
qkv_w_tensor, num_head, dim_head, dim_embed);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32/fp16/int8."));
break;
}
switch (qkv_b_tensor->dtype()) {
case paddle::experimental::DataType::FLOAT16:
QKVBiasProcessFuseQKV<platform::float16>(
qkv_b_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::FLOAT32:
QKVBiasProcessFuseQKV<float>(qkv_b_tensor, num_head, dim_head, dim_embed);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported bias dtype. "
"we now only support fp32/fp16."));
break;
}
}
// Just use for fused_multi_transformer_int8 AddOpCompat(OpCompat("gelu"))
inline void TransposeWeights(phi::DenseTensor* weight_tensor) { .AddInput("X")
int m = weight_tensor->dims()[0]; .IsTensor()
int n = weight_tensor->dims()[1]; .End()
phi::DenseTensor tmp_weight_tensor; .AddOutput("Out")
auto tmp_weight_data = .IsTensor()
tmp_weight_tensor.mutable_data<int8_t>({n, m}, platform::CPUPlace()); .End()
auto weight_data = weight_tensor->data<int8_t>(); .AddAttr("approximate")
for (int i = 0; i < m; ++i) { .IsType<bool>()
for (int j = 0; j < n; ++j) { .End();
int in_idx = i * n + j;
int out_idx = j * m + i;
tmp_weight_data[out_idx] = weight_data[in_idx];
}
}
weight_tensor->Resize({n, m});
auto new_weight_data =
weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
memcpy(new_weight_data, tmp_weight_data, sizeof(int8_t) * m * n);
}
inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) { AddOpCompat(OpCompat("relu"))
auto var_desc = VarDesc(name); .AddInput("X")
var_desc.SetDataType(framework::proto::VarType::FP32); .IsTensor()
var_desc.SetPersistable(true); .End()
auto node = graph->CreateVarNode(&var_desc); .AddOutput("Out")
return node; .IsTensor()
.End();
} }
int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
const std::string& name_scope, Graph* graph, const std::string& name_scope, Scope* scope) const {
Scope* scope) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8"); bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) { if (enable_int8) {
VLOG(3) << "FusedMultiTransformerEncoderPass with int8"; VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with int8";
} else { } else {
VLOG(3) << "FusedMultiTransformerEncoderPass with fp"; VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with fp";
} }
// Create pattern. // Create pattern.
patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern( patterns::FusedMultiTransformerEncoderFuseQKVPattern
pattern, name_scope); fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_pattern(); fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc // Create New OpDesc
auto fuse_creater = [&](Node* input0, auto fuse_creater = [&](Node* input0,
...@@ -1316,13 +2578,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1316,13 +2578,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* layer_norm_variance, Node* layer_norm_variance,
Node* matmul0, Node* matmul0,
Node* matmul0_w, Node* matmul0_w,
Node* matmul1_w,
Node* matmul2_w,
Node* eltadd0_b, Node* eltadd0_b,
Node* eltadd1_b, Node* split0_k_out,
Node* eltadd2_b, Node* split0_v_out,
Node* transpose2_1_out,
Node* transpose2_2_out,
Node* eltadd_qk_b, Node* eltadd_qk_b,
Node* reshape2_0, Node* reshape2_0,
Node* matmul_linear, Node* matmul_linear,
...@@ -1340,6 +2598,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1340,6 +2598,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) { Node* ffn_output) {
auto* matmul0_op = matmul0->Op(); auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op(); auto* matmul_linear_op = matmul_linear->Op();
...@@ -1355,43 +2614,28 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1355,43 +2614,28 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2; int layer_idx = atoi(ln_idx_str.c_str()) / 2;
auto* wq_tensor = auto* qkv_w_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>();
auto* wk_tensor = auto* qkv_b_tensor =
scope->FindVar(matmul1_w->Name())->GetMutable<phi::DenseTensor>();
auto* wv_tensor =
scope->FindVar(matmul2_w->Name())->GetMutable<phi::DenseTensor>();
auto* bq_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
auto* bk_tensor =
scope->FindVar(eltadd1_b->Name())->GetMutable<phi::DenseTensor>();
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
// NOTE(minghaoBD): to make it compatible with strucutured pruning on // NOTE(minghaoBD): to make it compatible with strucutured pruning on
// num_head dimension: // num_head dimension:
// 1. get dim_head from reshape.shape[3], dim_embed from // 1. get dim_head from reshape.shape[3], dim_embed from
// layer_norm_bias.shape[0] // layer_norm_bias.shape[0]
// 2. calculate num_head according to wq_tensor.shape[1] and dim_head // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
auto reshape_desc = reshape2_0->Op(); auto reshape_desc = reshape2_0->Op();
int dim_head = int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")) PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3); .at(3) /
3; // 3 for qkv
auto* layer_norm_bias_tensor = auto* layer_norm_bias_tensor =
scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
int dim_embed = layer_norm_bias_tensor->dims()[0]; int dim_embed = layer_norm_bias_tensor->dims()[0];
int num_head = wq_tensor->dims()[1] / dim_head; int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head;
QKVWeightsBiasProcess(wq_tensor, QKVWeightsBiasProcessFuseQKV(
wk_tensor, qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
wv_tensor,
bq_tensor,
bk_tensor,
bv_tensor,
num_head,
dim_head,
dim_embed);
if (enable_int8) { if (enable_int8) {
auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name()) auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name())
...@@ -1406,18 +2650,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1406,18 +2650,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
TransposeWeights(ffn1_w_tensor); TransposeWeights(ffn1_w_tensor);
} }
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* combined_w_desc = matmul0_w->Var();
combined_w_desc->SetShape({3, num_head, dim_head, dim_embed});
combined_w_desc->SetPersistable(true);
auto* combined_bias_desc = eltadd0_b->Var();
combined_bias_desc->SetShape({3, num_head, dim_head});
combined_bias_desc->SetPersistable(true);
scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()});
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
// create fused_multi_transformer // create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType(enable_int8 fused_multi_transformer_op_desc.SetType(enable_int8
...@@ -1442,7 +2674,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1442,7 +2674,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024 // FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType( cache_kv_desc.SetDataType(
framework::TransToProtoVarType(bq_tensor->dtype())); framework::TransToProtoVarType(qkv_b_tensor->dtype()));
cache_kv_desc.SetPersistable(false); cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
...@@ -1455,9 +2687,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1455,9 +2687,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr( fill_const_op_desc.SetAttr("dtype",
"dtype", static_cast<int>(framework::TransToProtoVarType(
static_cast<int>(framework::TransToProtoVarType(bq_tensor->dtype()))); qkv_b_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...@@ -1490,12 +2722,17 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1490,12 +2722,17 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute
fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
// Quantization attribute/Input // Quantization attribute/Input
if (enable_int8) { if (enable_int8) {
auto* dev_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
// Set input scale // Set input scale
std::string qkv_input_name = matmul0_op->Input("X")[0]; std::string qkv_input_name = matmul0_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST( auto qkv_in_scale = PADDLE_GET_CONST(
...@@ -1512,6 +2749,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1512,6 +2749,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
// Calc outscale and Set them // Calc outscale and Set them
// TODO(wufeisheng): Currently just match layer-wise weight scale, where
// channel-wise weight scale should also be surpported.
auto qkv_weight_scale = auto qkv_weight_scale =
PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale")); PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale"));
auto out_weight_scale = auto out_weight_scale =
...@@ -1555,36 +2794,44 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1555,36 +2794,44 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
auto ffn1_out_scale_var = auto ffn1_out_scale_var =
scope->Var(ffn_matmul1_w->Name() + "_out_scale"); scope->Var(ffn_matmul1_w->Name() + "_out_scale");
auto qkv_out_scale_data = auto* qkv_out_scale_tensor =
qkv_out_scale_var->GetMutable<phi::DenseTensor>() qkv_out_scale_var->GetMutable<phi::DenseTensor>();
->mutable_data<float>({3 * dim_embed}, platform::CPUPlace()); qkv_out_scale_tensor->Resize({3 * dim_embed});
dev_ctx->Alloc<float>(qkv_out_scale_tensor);
auto qkv_out_scale_data = qkv_out_scale_tensor->data<float>();
memcpy(qkv_out_scale_data, memcpy(qkv_out_scale_data,
qkv_out_scales.data(), qkv_out_scales.data(),
qkv_out_scales.size() * sizeof(float)); qkv_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput( fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"}); "QKVOutScale", {matmul0_w->Name() + "_out_scale"});
auto out_out_scale_data = auto* out_out_scale_tensor =
out_out_scale_var->GetMutable<phi::DenseTensor>() out_out_scale_var->GetMutable<phi::DenseTensor>();
->mutable_data<float>({dim_embed}, platform::CPUPlace()); out_out_scale_tensor->Resize({dim_embed});
dev_ctx->Alloc<float>(out_out_scale_tensor);
auto out_out_scale_data = out_out_scale_tensor->data<float>();
memcpy(out_out_scale_data, memcpy(out_out_scale_data,
out_out_scales.data(), out_out_scales.data(),
out_out_scales.size() * sizeof(float)); out_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput( fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
auto ffn0_out_scale_data = auto* ffn0_out_scale_tensor =
ffn0_out_scale_var->GetMutable<phi::DenseTensor>() ffn0_out_scale_var->GetMutable<phi::DenseTensor>();
->mutable_data<float>({4 * dim_embed}, platform::CPUPlace()); ffn0_out_scale_tensor->Resize({4 * dim_embed});
dev_ctx->Alloc<float>(ffn0_out_scale_tensor);
auto ffn0_out_scale_data = ffn0_out_scale_tensor->data<float>();
memcpy(ffn0_out_scale_data, memcpy(ffn0_out_scale_data,
ffn0_out_scales.data(), ffn0_out_scales.data(),
ffn0_out_scales.size() * sizeof(float)); ffn0_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput( fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
auto ffn1_out_scale_data = auto* ffn1_out_scale_tensor =
ffn1_out_scale_var->GetMutable<phi::DenseTensor>() ffn1_out_scale_var->GetMutable<phi::DenseTensor>();
->mutable_data<float>({dim_embed}, platform::CPUPlace()); ffn1_out_scale_tensor->Resize({dim_embed});
dev_ctx->Alloc<float>(ffn1_out_scale_tensor);
auto ffn1_out_scale_data = ffn1_out_scale_tensor->data<float>();
memcpy(ffn1_out_scale_data, memcpy(ffn1_out_scale_data,
ffn1_out_scales.data(), ffn1_out_scales.data(),
ffn1_out_scales.size() * sizeof(float)); ffn1_out_scales.size() * sizeof(float));
...@@ -1641,27 +2888,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1641,27 +2888,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
auto while_Xs = while0->Op()->Input("X"); auto while_Xs = while0->Op()->Input("X");
while_Xs.erase( while_Xs.erase(
std::remove( std::remove(
std::begin(while_Xs), std::end(while_Xs), transpose2_1_out->Name()), std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), transpose2_2_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), matmul1_w->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), matmul2_w->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), eltadd1_b->Name()),
std::end(while_Xs)); std::end(while_Xs));
while_Xs.erase( while_Xs.erase(
std::remove( std::remove(
std::begin(while_Xs), std::end(while_Xs), eltadd2_b->Name()), std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()),
std::end(while_Xs)); std::end(while_Xs));
while_Xs.emplace_back(cache_kv->Name()); while_Xs.emplace_back(cache_kv->Name());
while0->Op()->SetInput("X", while_Xs); while0->Op()->SetInput("X", while_Xs);
...@@ -1670,13 +2901,13 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1670,13 +2901,13 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
// 1. delete k, v // 1. delete k, v
// 2. add cache_kv // 2. add cache_kv
auto while_Outs = while0->Op()->Output("Out"); auto while_Outs = while0->Op()->Output("Out");
while_Outs.erase(std::remove(std::begin(while_Outs), while_Outs.erase(
std::end(while_Outs), std::remove(
transpose2_1_out->Name()), std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()),
std::end(while_Outs)); std::end(while_Outs));
while_Outs.erase(std::remove(std::begin(while_Outs), while_Outs.erase(
std::end(while_Outs), std::remove(
transpose2_2_out->Name()), std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()),
std::end(while_Outs)); std::end(while_Outs));
while_Outs.emplace_back(cache_kv->Name()); while_Outs.emplace_back(cache_kv->Name());
while0->Op()->SetOutput("Out", while_Outs); while0->Op()->SetOutput("Out", while_Outs);
...@@ -1684,213 +2915,200 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1684,213 +2915,200 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
// link CacheKV to while // link CacheKV to while
IR_NODE_LINK_TO(cache_kv, while0) IR_NODE_LINK_TO(cache_kv, while0)
// unlink origin KV output to while // unlink origin KV output to while
IR_NODE_UNLINK(transpose2_1_out, while0); IR_NODE_UNLINK(split0_k_out, while0);
IR_NODE_UNLINK(transpose2_2_out, while0); IR_NODE_UNLINK(split0_v_out, while0);
IR_NODE_UNLINK(while0, transpose2_1_out); IR_NODE_UNLINK(while0, split0_k_out);
IR_NODE_UNLINK(while0, transpose2_2_out); IR_NODE_UNLINK(while0, split0_v_out);
// unlink KV weight/bias to while after merged into Q weight/bias
IR_NODE_UNLINK(matmul1_w, while0);
IR_NODE_UNLINK(matmul2_w, while0);
IR_NODE_UNLINK(eltadd1_b, while0);
IR_NODE_UNLINK(eltadd2_b, while0);
}; };
int fusion_count{0}; int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
if (!IsCompat(subgraph, graph)) { if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_encoder pass in " LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv "
"op compat failed."; "pass in op compat failed.";
return; return;
} }
VLOG(4) << "handle MultiTransformer encoder fuse"; VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul1, matmul1, fused_multi_transformer_pattern); input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul1_out, matmul1_out, fused_multi_transformer_pattern); layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul1_w, matmul1_w, fused_multi_transformer_pattern); matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1, reshape2_1, fused_multi_transformer_pattern); matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern); matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1, transpose2_1, fused_multi_transformer_pattern); reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern); transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul2, matmul2, fused_multi_transformer_pattern); split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_out, matmul2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_w, matmul2_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2, reshape2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2, transpose2_2, fused_multi_transformer_pattern); split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern); split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
attention_output, attention_output, fused_multi_transformer_pattern) split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(while0, while0, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern); ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale, ffn_layer_norm_scale,
fused_multi_transformer_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias, ffn_layer_norm_bias,
fused_multi_transformer_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean, ffn_layer_norm_mean,
fused_multi_transformer_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance, ffn_layer_norm_variance,
fused_multi_transformer_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out, ffn_layer_norm_out,
fused_multi_transformer_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern); ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern); ffn_matmul0_out,
GET_IR_NODE_FROM_SUBGRAPH( fused_multi_transformer_fuse_qkv_pattern);
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern); ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern); ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern); ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_pattern); ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern); ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern); ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern); ffn_matmul1_out,
GET_IR_NODE_FROM_SUBGRAPH( fused_multi_transformer_fuse_qkv_pattern);
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern); ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern); ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern); ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern) ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_pattern)
// nodes need be removed // nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_pattern); eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1, eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_b, eltadd1_b, fused_multi_transformer_pattern); eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_out, eltadd1_out, fused_multi_transformer_pattern); eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd2, eltadd2, fused_multi_transformer_pattern); matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_b, eltadd2_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_out, eltadd2_out, fused_multi_transformer_pattern); matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_pattern); scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern); scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_pattern); eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern); eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern); eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_pattern); softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern); softmax_qk_out,
GET_IR_NODE_FROM_SUBGRAPH( fused_multi_transformer_fuse_qkv_pattern);
matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern); matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern); matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern); reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out, transpose2_qkv_out,
fused_multi_transformer_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_pattern) matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern) matmul_linear_w,
GET_IR_NODE_FROM_SUBGRAPH( fused_multi_transformer_fuse_qkv_pattern)
matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
GET_IR_NODE_FROM_SUBGRAPH( matmul_linear_out,
eltadd_linear, eltadd_linear, fused_multi_transformer_pattern) fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern) eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern) eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_pattern) while0, while0, fused_multi_transformer_fuse_qkv_pattern)
fuse_creater(input0, fuse_creater(input0,
layer_norm, layer_norm,
...@@ -1900,13 +3118,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1900,13 +3118,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
layer_norm_variance, layer_norm_variance,
matmul0, matmul0,
matmul0_w, matmul0_w,
matmul1_w,
matmul2_w,
eltadd0_b, eltadd0_b,
eltadd1_b, split0_k_out,
eltadd2_b, split0_v_out,
transpose2_1_out,
transpose2_2_out,
eltadd_qk_b, eltadd_qk_b,
reshape2_0, reshape2_0,
matmul_linear, matmul_linear,
...@@ -1924,6 +3138,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1924,6 +3138,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_act,
ffn_output); ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm, std::unordered_set<const Node*> marked_nodes({layer_norm,
...@@ -1931,31 +3146,21 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1931,31 +3146,21 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
layer_norm_variance, layer_norm_variance,
layer_norm_out, layer_norm_out,
matmul0, matmul0,
matmul1,
matmul2,
matmul0_out, matmul0_out,
matmul1_out,
matmul2_out,
eltadd0, eltadd0,
eltadd1,
eltadd2,
eltadd0_out, eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0, reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out, reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0, transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out, transpose2_0_out,
transpose2_1_out, split0,
transpose2_2_out, split0_q_out,
split0_k_out,
split0_v_out,
matmul_qk, matmul_qk,
matmul_qk_out, matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk, eltadd_qk,
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
...@@ -1984,8 +3189,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1984,8 +3189,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_eltadd1, ffn_eltadd1,
ffn_eltadd0_out, ffn_eltadd0_out,
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_act,
ffn_gelu_out, ffn_act_out,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
...@@ -1997,23 +3202,25 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1997,23 +3202,25 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
return fusion_count; return fusion_count;
} }
void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const { void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
scope, scope,
platform::errors::Fatal( platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null.")); "During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope); int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) { if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true)); graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count));
} }
AddStatis(fusion_count); AddStatis(fusion_count);
} }
FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() { FusedMultiTransformerEncoderFuseQKVPass::
FusedMultiTransformerEncoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm")) AddOpCompat(OpCompat("layer_norm"))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
...@@ -2041,6 +3248,23 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() { ...@@ -2041,6 +3248,23 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
.IsNumGT(0) .IsNumGT(0)
.End(); .End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2")) AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H) .AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor() .IsTensor()
...@@ -2168,56 +3392,54 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() { ...@@ -2168,56 +3392,54 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
.End(); .End();
} }
int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( int MultiDevicesFusedMultiTransformerEncoderPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const { Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) {
VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with int8";
} else {
VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with fp";
}
// Create pattern. // Create pattern.
patterns::FusedMultiTransformerEncoderFuseQKVPattern patterns::MultiDevicesFusedMultiTransformerEncoderPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope); multi_devices_fused_multi_transformer_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern(); multi_devices_fused_multi_transformer_pattern();
// Create New OpDesc // Create New OpDesc
auto fuse_creater = [&](Node* input0, auto fuse_creater = [&](Node* input0,
Node* c_identity,
Node* layer_norm, Node* layer_norm,
Node* layer_norm_scale, Node* layer_norm_scale,
Node* layer_norm_bias, Node* layer_norm_bias,
Node* layer_norm_mean, Node* layer_norm_mean,
Node* layer_norm_variance, Node* layer_norm_variance,
Node* matmul0,
Node* matmul0_w, Node* matmul0_w,
Node* matmul1_w,
Node* matmul2_w,
Node* eltadd0_b, Node* eltadd0_b,
Node* split0_k_out, Node* eltadd1_b,
Node* split0_v_out, Node* eltadd2_b,
Node* transpose2_1_out,
Node* transpose2_2_out,
Node* eltadd_qk_b, Node* eltadd_qk_b,
Node* reshape2_0, Node* reshape2_0,
Node* matmul_linear,
Node* matmul_linear_w, Node* matmul_linear_w,
Node* eltadd_linear_b, Node* eltadd_linear_b,
Node* while0,
Node* ffn_layer_norm, Node* ffn_layer_norm,
Node* ffn_layer_norm_scale, Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias, Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean, Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance, Node* ffn_layer_norm_variance,
Node* ffn_matmul0,
Node* ffn_matmul0_w, Node* ffn_matmul0_w,
Node* ffn_matmul1,
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_output) { Node* ffn_act,
auto* matmul0_op = matmul0->Op(); Node* ffn_layer_norm_out) {
auto* matmul_linear_op = matmul_linear->Op(); auto reshape_desc = reshape2_0->Op();
auto* ffn_matmul_0_op = ffn_matmul0->Op(); int num_head =
auto* ffn_matmul_1_op = ffn_matmul1->Op(); PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(2);
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3);
// Calc index of transformer layer by LayerNorm Scale name // Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes: // This calculation assumes:
...@@ -2228,47 +3450,47 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2228,47 +3450,47 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2; int layer_idx = atoi(ln_idx_str.c_str()) / 2;
auto* qkv_w_tensor = auto* wq_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>();
auto* qkv_b_tensor = auto* wk_tensor =
scope->FindVar(matmul1_w->Name())->GetMutable<phi::DenseTensor>();
auto* wv_tensor =
scope->FindVar(matmul2_w->Name())->GetMutable<phi::DenseTensor>();
auto* bq_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
auto* bk_tensor =
scope->FindVar(eltadd1_b->Name())->GetMutable<phi::DenseTensor>();
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
// NOTE(minghaoBD): to make it compatible with strucutured pruning on int dim_embed = wq_tensor->dims()[0];
// num_head dimension:
// 1. get dim_head from reshape.shape[3], dim_embed from
// layer_norm_bias.shape[0]
// 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
auto reshape_desc = reshape2_0->Op();
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3) /
3; // 3 for qkv
auto* layer_norm_bias_tensor =
scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
int dim_embed = layer_norm_bias_tensor->dims()[0];
int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head;
QKVWeightsBiasProcessFuseQKV( QKVWeightsBiasProcess(wq_tensor,
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); wk_tensor,
wv_tensor,
bq_tensor,
bk_tensor,
bv_tensor,
num_head,
dim_head,
dim_embed);
if (enable_int8) { // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name()) auto* combined_w_desc = matmul0_w->Var();
->GetMutable<phi::DenseTensor>(); combined_w_desc->SetShape({3, num_head, dim_head, dim_embed});
auto* ffn0_w_tensor = combined_w_desc->SetPersistable(true);
scope->FindVar(ffn_matmul0_w->Name())->GetMutable<phi::DenseTensor>();
auto* ffn1_w_tensor =
scope->FindVar(ffn_matmul1_w->Name())->GetMutable<phi::DenseTensor>();
TransposeWeights(out_linear_w_tensor); auto* combined_bias_desc = eltadd0_b->Var();
TransposeWeights(ffn0_w_tensor); combined_bias_desc->SetShape({3, num_head, dim_head});
TransposeWeights(ffn1_w_tensor); combined_bias_desc->SetPersistable(true);
}
scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()});
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
// create fused_multi_transformer // create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType(enable_int8 fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
? "fused_multi_transformer_int8"
: "fused_multi_transformer");
// 1. Input setting // 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
...@@ -2288,7 +3510,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2288,7 +3510,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024 // FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType( cache_kv_desc.SetDataType(
framework::TransToProtoVarType(qkv_b_tensor->dtype())); framework::TransToProtoVarType(wq_tensor->dtype()));
cache_kv_desc.SetPersistable(false); cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
...@@ -2301,9 +3523,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2301,9 +3523,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", fill_const_op_desc.SetAttr(
static_cast<int>(framework::TransToProtoVarType( "dtype",
qkv_b_tensor->dtype()))); static_cast<int>(framework::TransToProtoVarType(wq_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...@@ -2329,137 +3551,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2329,137 +3551,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
{ffn_eltadd1_b->Name()}); {ffn_eltadd1_b->Name()});
// 2. Output setting // 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()}); fused_multi_transformer_op_desc.SetOutput("Out",
{ffn_layer_norm_out->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});
// Attribute setting // Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", false);
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
fused_multi_transformer_op_desc.SetAttr("act_method",
{ffn_act->Op()->Type()});
// Quantization attribute/Input // parallel ring id
if (enable_int8) { auto* c_identity_op = c_identity->Op();
// Set input scale fused_multi_transformer_op_desc.SetAttr("ring_id",
std::string qkv_input_name = matmul0_op->Input("X")[0]; c_identity_op->GetAttr("ring_id"));
auto qkv_in_scale = PADDLE_GET_CONST(
float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name));
std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
auto out_linear_in_scale = PADDLE_GET_CONST(
float,
matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0];
auto ffn0_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name));
std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
auto ffn1_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
// Calc outscale and Set them
// TODO(wufeisheng): Currently just match layer-wise weight scale, where
// channel-wise weight scale should also be surpported.
auto qkv_weight_scale =
PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale"));
auto out_weight_scale =
PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale"));
auto ffn0_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale"));
auto ffn1_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale"));
auto qkv_out_scales = std::vector<float>(
3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f));
auto out_out_scales = std::vector<float>(
dim_embed,
(out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f));
auto ffn0_out_scales = std::vector<float>(
4 * dim_embed,
(ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f));
auto ffn1_out_scales = std::vector<float>(
dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f));
// Inverse input scale
qkv_in_scale = 1.0f / qkv_in_scale;
out_linear_in_scale = 1.0f / out_linear_in_scale;
ffn0_in_scale = 1.0f / ffn0_in_scale;
ffn1_in_scale = 1.0f / ffn1_in_scale;
fused_multi_transformer_op_desc.SetAttr("qkv_in_scale",
std::vector<float>{qkv_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"out_linear_in_scale", std::vector<float>{out_linear_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn1_in_scale", std::vector<float>{ffn0_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn2_in_scale", std::vector<float>{ffn1_in_scale});
auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale");
auto out_out_scale_var =
scope->Var(matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_var =
scope->Var(ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_var =
scope->Var(ffn_matmul1_w->Name() + "_out_scale");
auto qkv_out_scale_data =
qkv_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({3 * dim_embed}, platform::CPUPlace());
memcpy(qkv_out_scale_data,
qkv_out_scales.data(),
qkv_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"});
auto out_out_scale_data =
out_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({dim_embed}, platform::CPUPlace());
memcpy(out_out_scale_data,
out_out_scales.data(),
out_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
auto ffn0_out_scale_data =
ffn0_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({4 * dim_embed}, platform::CPUPlace());
memcpy(ffn0_out_scale_data,
ffn0_out_scales.data(),
ffn0_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
auto ffn1_out_scale_data =
ffn1_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({dim_embed}, platform::CPUPlace());
memcpy(ffn1_out_scale_data,
ffn1_out_scales.data(),
ffn1_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
}
auto* fused_multi_transformer = auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc); graph->CreateOpNode(&fused_multi_transformer_op_desc);
if (enable_int8) {
auto qkv_out_scale_node =
CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
auto out_out_scale_node = CreatePersistableVarNode(
graph, matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
}
IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
...@@ -2477,291 +3589,374 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2477,291 +3589,374 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); IR_NODE_LINK_TO(fused_multi_transformer, ffn_layer_norm_out);
// rewrite while OP input
// 1. delete k, v
// 2. delete matmul1/2_w eltadd1/2_w
// 3. add cache_kv
auto while_Xs = while0->Op()->Input("X");
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()),
std::end(while_Xs));
while_Xs.emplace_back(cache_kv->Name());
while0->Op()->SetInput("X", while_Xs);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto while_Outs = while0->Op()->Output("Out");
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()),
std::end(while_Outs));
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()),
std::end(while_Outs));
while_Outs.emplace_back(cache_kv->Name());
while0->Op()->SetOutput("Out", while_Outs);
// link CacheKV to while
IR_NODE_LINK_TO(cache_kv, while0)
// unlink origin KV output to while
IR_NODE_UNLINK(split0_k_out, while0);
IR_NODE_UNLINK(split0_v_out, while0);
IR_NODE_UNLINK(while0, split0_k_out);
IR_NODE_UNLINK(while0, split0_v_out);
}; };
int fusion_count{0}; int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
if (!IsCompat(subgraph, graph)) { if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv " LOG(WARNING) << "fused_multi_transformer_encoder pass in "
"pass in op compat failed."; "op compat failed.";
return; return;
} }
VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse"; VLOG(4) << "handle MultiTransformer encoder fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern); input0, input0, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity0,
c_identity0,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity0_out,
c_identity0_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity1,
c_identity1,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity1_out,
c_identity1_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity2,
c_identity2,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity2_out,
c_identity2_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale, layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias, layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean, layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance, layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out, layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity,
ffn_c_identity,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out,
ffn_c_identity_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern); matmul0, matmul0, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(matmul0_out,
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern); matmul0_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern); matmul0_w, matmul0_w, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern); reshape2_0, reshape2_0, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out, reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(transpose2_0,
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern); transpose2_0,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out, transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern); matmul1, matmul1, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul1_out,
matmul1_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern); matmul1_w, matmul1_w, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern); reshape2_1, reshape2_1, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out,
reshape2_1_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1,
transpose2_1,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out,
transpose2_1_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern); scale_q, scale_q, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_q_out,
scale_q_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2, matmul2, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul2_out,
matmul2_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_w, matmul2_w, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2, reshape2_2, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out,
reshape2_2_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2,
transpose2_2,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out,
transpose2_2_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(attention_output,
attention_output,
multi_devices_fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm, ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale, ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias, ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean, ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance, ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out, ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0,
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern); ffn_matmul0,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out, ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_w,
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern); ffn_matmul0_w,
GET_IR_NODE_FROM_SUBGRAPH( multi_devices_fused_multi_transformer_pattern);
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0,
GET_IR_NODE_FROM_SUBGRAPH( ffn_eltadd0,
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_b,
ffn_eltadd0_b,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out, ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); ffn_act, ffn_act, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_act_out,
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); ffn_act_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1,
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); ffn_matmul1,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out, ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_w,
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern); ffn_matmul1_w,
GET_IR_NODE_FROM_SUBGRAPH( multi_devices_fused_multi_transformer_pattern);
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum,
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern); ffn_c_allreduce_sum,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out,
ffn_c_allreduce_sum_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1,
ffn_eltadd1,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_b,
ffn_eltadd1_b,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out, ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out, ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern) multi_devices_fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern) ffn_output, ffn_output, multi_devices_fused_multi_transformer_pattern)
// nodes need be removed // nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern); eltadd0, eltadd0, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern); eltadd0_b, eltadd0_b, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out,
eltadd0_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern); eltadd1, eltadd1, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); eltadd1_b, eltadd1_b, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out,
eltadd1_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern); eltadd2, eltadd2, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern); eltadd2_b, eltadd2_b, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out,
eltadd2_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); matmul_qk, matmul_qk, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out,
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern); matmul_qk_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern); eltadd_qk, eltadd_qk, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b,
eltadd_qk_b,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out,
eltadd_qk_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern); softmax_qk, softmax_qk, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out, softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); matmul_qkv, matmul_qkv, multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out, matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv,
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern); reshape2_qkv,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out, reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv, transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out, transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern); multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(matmul_linear,
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern) matmul_linear,
multi_devices_fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w, matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern) multi_devices_fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out, matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern) multi_devices_fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum,
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern) c_allreduce_sum,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out,
c_allreduce_sum_out,
multi_devices_fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear,
eltadd_linear,
multi_devices_fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b, eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern) multi_devices_fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out, eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern) multi_devices_fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
while0, while0, fused_multi_transformer_fuse_qkv_pattern) eltadd_out, eltadd_out, multi_devices_fused_multi_transformer_pattern)
fuse_creater(input0, fuse_creater(input0,
c_identity0,
layer_norm, layer_norm,
layer_norm_scale, layer_norm_scale,
layer_norm_bias, layer_norm_bias,
layer_norm_mean, layer_norm_mean,
layer_norm_variance, layer_norm_variance,
matmul0,
matmul0_w, matmul0_w,
matmul1_w,
matmul2_w,
eltadd0_b, eltadd0_b,
split0_k_out, eltadd1_b,
split0_v_out, eltadd2_b,
transpose2_1_out,
transpose2_2_out,
eltadd_qk_b, eltadd_qk_b,
reshape2_0, reshape2_0,
matmul_linear,
matmul_linear_w, matmul_linear_w,
eltadd_linear_b, eltadd_linear_b,
while0,
ffn_layer_norm, ffn_layer_norm,
ffn_layer_norm_scale, ffn_layer_norm_scale,
ffn_layer_norm_bias, ffn_layer_norm_bias,
ffn_layer_norm_mean, ffn_layer_norm_mean,
ffn_layer_norm_variance, ffn_layer_norm_variance,
ffn_matmul0,
ffn_matmul0_w, ffn_matmul0_w,
ffn_matmul1,
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_output); ffn_act,
ffn_layer_norm_out);
std::unordered_set<const Node*> marked_nodes({layer_norm,
std::unordered_set<const Node*> marked_nodes({c_identity0,
c_identity0_out,
c_identity1,
c_identity1_out,
c_identity2,
c_identity2_out,
layer_norm,
layer_norm_mean, layer_norm_mean,
layer_norm_variance, layer_norm_variance,
layer_norm_out, layer_norm_out,
matmul0, matmul0,
matmul1,
matmul2,
matmul0_out, matmul0_out,
matmul1_out,
matmul2_out,
eltadd0, eltadd0,
eltadd1,
eltadd2,
eltadd0_out, eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0, reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out, reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0, transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out, transpose2_0_out,
split0, transpose2_1_out,
split0_q_out, transpose2_2_out,
split0_k_out, scale_q,
split0_v_out, scale_q_out,
matmul_qk, matmul_qk,
matmul_qk_out, matmul_qk_out,
scale_qk,
scale_qk_out,
eltadd_qk, eltadd_qk,
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
...@@ -2775,23 +3970,29 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2775,23 +3970,29 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
transpose2_qkv_out, transpose2_qkv_out,
matmul_linear, matmul_linear,
matmul_linear_out, matmul_linear_out,
c_allreduce_sum,
c_allreduce_sum_out,
eltadd_linear, eltadd_linear,
eltadd_linear_out, eltadd_linear_out,
eltadd_out, eltadd_out,
ffn_layer_norm, ffn_layer_norm,
ffn_layer_norm_mean, ffn_layer_norm_mean,
ffn_layer_norm_variance, ffn_layer_norm_variance,
ffn_layer_norm_out, ffn_c_identity,
ffn_c_identity_out,
ffn_matmul0, ffn_matmul0,
ffn_matmul1, ffn_matmul1,
ffn_matmul0_out, ffn_matmul0_out,
ffn_matmul1_out, ffn_matmul1_out,
ffn_c_allreduce_sum,
ffn_c_allreduce_sum_out,
ffn_eltadd0, ffn_eltadd0,
ffn_eltadd1, ffn_eltadd1,
ffn_eltadd0_out, ffn_eltadd0_out,
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_act,
ffn_gelu_out, ffn_act_out,
ffn_output,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
...@@ -2803,25 +4004,25 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2803,25 +4004,25 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
return fusion_count; return fusion_count;
} }
void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const { void MultiDevicesFusedMultiTransformerEncoderPass::ApplyImpl(
Graph* graph) const {
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
scope, scope,
platform::errors::Fatal( platform::errors::Fatal(
"During the fused_multi_transformer_encoder pass, " "During the multi_transformer pass, The scope should not be null."));
"The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope); int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) { if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count));
} }
AddStatis(fusion_count); AddStatis(fusion_count);
} }
FusedMultiTransformerEncoderFuseQKVPass:: MultiDevicesFusedMultiTransformerEncoderPass::
FusedMultiTransformerEncoderFuseQKVPass() { MultiDevicesFusedMultiTransformerEncoderPass() {
AddOpCompat(OpCompat("layer_norm")) AddOpCompat(OpCompat("layer_norm"))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
...@@ -2849,23 +4050,6 @@ FusedMultiTransformerEncoderFuseQKVPass:: ...@@ -2849,23 +4050,6 @@ FusedMultiTransformerEncoderFuseQKVPass::
.IsNumGT(0) .IsNumGT(0)
.End(); .End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsType<float>() // copy to new op. so unconstrained.
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2")) AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H) .AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor() .IsTensor()
...@@ -2935,24 +4119,20 @@ FusedMultiTransformerEncoderFuseQKVPass:: ...@@ -2935,24 +4119,20 @@ FusedMultiTransformerEncoderFuseQKVPass::
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.End(); .End();
AddOpCompat(OpCompat("matmul")) AddOpCompat(OpCompat("scale"))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
.End() .End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out") .AddOutput("Out")
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("alpha") .AddAttr("scale")
.IsNumGE(0.0f) .IsType<float>() // copy to new op. so unconstrained.
.IsNumLE(1.0f)
.End() .End()
.AddAttr("transpose_X") .AddAttr("bias")
.IsBoolEQ(false) .IsNumEQ(0.f)
.End() .End()
.AddAttr("transpose_Y") .AddAttr("bias_after_scale") // bias is 0, so unconstrained.
.IsType<bool>() .IsType<bool>()
.End(); .End();
...@@ -2978,18 +4158,12 @@ FusedMultiTransformerEncoderFuseQKVPass:: ...@@ -2978,18 +4158,12 @@ FusedMultiTransformerEncoderFuseQKVPass::
.IsType<bool>() .IsType<bool>()
.End(); .End();
AddOpCompat(OpCompat("while")) AddOpCompat(OpCompat("relu"))
.AddInput("X") // A set of variables, unconstrained .AddInput("X")
.End()
.AddInput("Condition") // An scalar
.IsTensor() .IsTensor()
.End() .End()
.AddOutput("Out") // A set of variables, unconstrained .AddOutput("Out")
.End() .IsTensor()
.AddOutput("StepScopes") // A vector of local scope, unconstrained
.End()
.AddAttr("sub_block")
.IsType<framework::BlockDesc*>()
.End(); .End();
} }
...@@ -3040,6 +4214,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3040,6 +4214,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) { Node* ffn_output) {
auto* matmul0_op = matmul0->Op(); auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op(); auto* matmul_linear_op = matmul_linear->Op();
...@@ -3163,6 +4338,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3163,6 +4338,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute // output dropout attribute
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
...@@ -3175,6 +4352,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3175,6 +4352,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
// Quantization attribute/Input // Quantization attribute/Input
if (enable_int8) { if (enable_int8) {
auto* dev_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
// Set input scale // Set input scale
std::string matmul_input_scale_suffix = c_identity_op->Input("X")[0]; std::string matmul_input_scale_suffix = c_identity_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST( auto qkv_in_scale = PADDLE_GET_CONST(
...@@ -3240,36 +4419,44 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3240,36 +4419,44 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
auto ffn1_out_scale_var = auto ffn1_out_scale_var =
scope->Var(ffn_matmul1_w->Name() + "_out_scale"); scope->Var(ffn_matmul1_w->Name() + "_out_scale");
auto qkv_out_scale_data = auto* qkv_out_scale_tensor =
qkv_out_scale_var->GetMutable<phi::DenseTensor>() qkv_out_scale_var->GetMutable<phi::DenseTensor>();
->mutable_data<float>({3 * dim_embed}, platform::CPUPlace()); qkv_out_scale_tensor->Resize({3 * dim_embed});
dev_ctx->Alloc<float>(qkv_out_scale_tensor);
auto qkv_out_scale_data = qkv_out_scale_tensor->data<float>();
memcpy(qkv_out_scale_data, memcpy(qkv_out_scale_data,
qkv_out_scales.data(), qkv_out_scales.data(),
qkv_out_scales.size() * sizeof(float)); qkv_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput( fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"}); "QKVOutScale", {matmul0_w->Name() + "_out_scale"});
auto out_out_scale_data = auto* out_out_scale_tensor =
out_out_scale_var->GetMutable<phi::DenseTensor>() out_out_scale_var->GetMutable<phi::DenseTensor>();
->mutable_data<float>({dim_embed}, platform::CPUPlace()); out_out_scale_tensor->Resize({dim_embed});
dev_ctx->Alloc<float>(out_out_scale_tensor);
auto out_out_scale_data = out_out_scale_tensor->data<float>();
memcpy(out_out_scale_data, memcpy(out_out_scale_data,
out_out_scales.data(), out_out_scales.data(),
out_out_scales.size() * sizeof(float)); out_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput( fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
auto ffn0_out_scale_data = auto* ffn0_out_scale_tensor =
ffn0_out_scale_var->GetMutable<phi::DenseTensor>() ffn0_out_scale_var->GetMutable<phi::DenseTensor>();
->mutable_data<float>({4 * dim_embed}, platform::CPUPlace()); ffn0_out_scale_tensor->Resize({4 * dim_embed});
dev_ctx->Alloc<float>(ffn0_out_scale_tensor);
auto ffn0_out_scale_data = ffn0_out_scale_tensor->data<float>();
memcpy(ffn0_out_scale_data, memcpy(ffn0_out_scale_data,
ffn0_out_scales.data(), ffn0_out_scales.data(),
ffn0_out_scales.size() * sizeof(float)); ffn0_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput( fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
auto ffn1_out_scale_data = auto* ffn1_out_scale_tensor =
ffn1_out_scale_var->GetMutable<phi::DenseTensor>() ffn1_out_scale_var->GetMutable<phi::DenseTensor>();
->mutable_data<float>({dim_embed}, platform::CPUPlace()); ffn1_out_scale_tensor->Resize({dim_embed});
dev_ctx->Alloc<float>(ffn1_out_scale_tensor);
auto ffn1_out_scale_data = ffn1_out_scale_tensor->data<float>();
memcpy(ffn1_out_scale_data, memcpy(ffn1_out_scale_data,
ffn1_out_scales.data(), ffn1_out_scales.data(),
ffn1_out_scales.size() * sizeof(float)); ffn1_out_scales.size() * sizeof(float));
...@@ -3464,9 +4651,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3464,9 +4651,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
...@@ -3603,6 +4790,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3603,6 +4790,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_act,
ffn_output); ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm, std::unordered_set<const Node*> marked_nodes({layer_norm,
...@@ -3661,8 +4849,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3661,8 +4849,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_eltadd1, ffn_eltadd1,
ffn_eltadd0_out, ffn_eltadd0_out,
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_act,
ffn_gelu_out, ffn_act_out,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
...@@ -3874,6 +5062,9 @@ REGISTER_PASS(fused_multi_transformer_encoder_pass, ...@@ -3874,6 +5062,9 @@ REGISTER_PASS(fused_multi_transformer_encoder_pass,
paddle::framework::ir::FusedMultiTransformerEncoderPass); paddle::framework::ir::FusedMultiTransformerEncoderPass);
REGISTER_PASS(fused_multi_transformer_encoder_fuse_qkv_pass, REGISTER_PASS(fused_multi_transformer_encoder_fuse_qkv_pass,
paddle::framework::ir::FusedMultiTransformerEncoderFuseQKVPass); paddle::framework::ir::FusedMultiTransformerEncoderFuseQKVPass);
REGISTER_PASS(
multi_devices_fused_multi_transformer_encoder_pass,
paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderPass);
REGISTER_PASS( REGISTER_PASS(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass, multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass,
paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderFuseQKVPass); paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderFuseQKVPass);
...@@ -3898,6 +5089,16 @@ REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass) ...@@ -3898,6 +5089,16 @@ REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass)
.LE("matmul", 1) .LE("matmul", 1)
.EQ("matmul_v2", 0) .EQ("matmul_v2", 0)
.EQ("softmax", 0)); .EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(multi_devices_fused_multi_transformer_encoder_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY( REGISTER_PASS_CAPABILITY(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass) multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass)
.AddCombination( .AddCombination(
......
...@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
// Q, K, V path // Q, K, V path
PATTERN_DECL_NODE(input0); PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0); PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1); PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2); PATTERN_DECL_NODE(matmul2);
...@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(transpose2_0_out); PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out); PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out); PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(scale_q);
PATTERN_DECL_NODE(scale_q_out);
// Q, K matmul // Q, K matmul
PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk);
...@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output); PATTERN_DECL_NODE(attention_output);
// while loop // post layer_norm
PATTERN_DECL_NODE(while0); PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
// Feed Forward nodes // Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0); PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w); PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out); PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out); PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu); PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_gelu_out); PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out); PATTERN_DECL_NODE(ffn_matmul1_out);
...@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output); PATTERN_DECL_NODE(ffn_output);
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
}; };
struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
...@@ -212,8 +216,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -212,8 +216,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out); PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu); PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_gelu_out); PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out); PATTERN_DECL_NODE(ffn_matmul1_out);
...@@ -226,6 +230,129 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -226,6 +230,129 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_output); PATTERN_DECL_NODE(ffn_output);
}; };
struct MultiDevicesFusedMultiTransformerEncoderPattern : public PatternBase {
MultiDevicesFusedMultiTransformerEncoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_encoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(c_identity0);
PATTERN_DECL_NODE(c_identity0_out);
PATTERN_DECL_NODE(c_identity1);
PATTERN_DECL_NODE(c_identity1_out);
PATTERN_DECL_NODE(c_identity2);
PATTERN_DECL_NODE(c_identity2_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(scale_q);
PATTERN_DECL_NODE(scale_q_out);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// post layer_norm
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
};
struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
: public PatternBase { : public PatternBase {
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern( MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern(
...@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern ...@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out); PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu); PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_gelu_out); PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out); PATTERN_DECL_NODE(ffn_matmul1_out);
...@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase { ...@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
Scope* scope) const; Scope* scope) const;
}; };
class MultiDevicesFusedMultiTransformerEncoderPass : public FusePassBase {
public:
MultiDevicesFusedMultiTransformerEncoderPass();
virtual ~MultiDevicesFusedMultiTransformerEncoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_encoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
: public FusePassBase { : public FusePassBase {
public: public:
......
...@@ -56,8 +56,8 @@ Scope* CreateParamScope() { ...@@ -56,8 +56,8 @@ Scope* CreateParamScope() {
// FFN: fc1 -> (gelu) -> fc2 // FFN: fc1 -> (gelu) -> fc2
AddVarToScope(param_scope, "ffn_weights0", {1024, 4096}); AddVarToScope(param_scope, "ffn_weights0", {1024, 4096});
AddVarToScope(param_scope, "ffn_weights1", {4096, 1024}); AddVarToScope(param_scope, "ffn_weights1", {4096, 1024});
AddVarToScope(param_scope, "ffn_bias_0", {4096}); AddVarToScope(param_scope, "ffn_bias0", {4096});
AddVarToScope(param_scope, "ffn_bias_1", {1024}); AddVarToScope(param_scope, "ffn_bias1", {1024});
return param_scope; return param_scope;
} }
...@@ -65,10 +65,9 @@ Scope* CreateParamScope() { ...@@ -65,10 +65,9 @@ Scope* CreateParamScope() {
TEST(FusedMultiTransformerEncoderPass, basic) { TEST(FusedMultiTransformerEncoderPass, basic) {
// inputs operator output // inputs operator output
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out // (x, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0 // (x, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1 // (x, weights_2) matmul_v2 -> matmul_out2
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0 // (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1 // (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2 // (matmul_out2, bias_2) elementwise_add -> eltadd_2
...@@ -78,7 +77,8 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -78,7 +77,8 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (reshape_0) transpose2 -> transpose_0 // (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1 // (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2 // (reshape_2) transpose2 -> transpose_2
// (transpose_0, transpose_1) matmul -> matmul_qk // (transpose_0) scale -> scale_0
// (scale_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
...@@ -86,35 +86,28 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -86,35 +86,28 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_linear) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0 // (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0 // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu // (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output // (layer_norm_out, ffn_eltadd1) elementwise_add -> ffn_output
// // (ffn_output, scale, bias) layer_norm -> ffn_layer_norm_out
// (transpose_1, transpose_2) while -> decoder block
Layers layers; Layers layers;
// MHA: pre LayerNorm // MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024}); auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc // MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 1024}, true); auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
auto* weights_1 = layers.data("weights1", {1024, 1024}, true); auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
auto* weights_2 = layers.data("weights2", {1024, 1024}, true); auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
auto* matmul_out_0 = auto* matmul_out_0 = layers.matmul_v2(x, weights_0, nullptr, false, false);
layers.matmul_v2(ln_out, weights_0, nullptr, false, true); auto* matmul_out_1 = layers.matmul_v2(x, weights_1, nullptr, false, false);
auto* matmul_out_1 = auto* matmul_out_2 = layers.matmul_v2(x, weights_2, nullptr, false, false);
layers.matmul_v2(ln_out, weights_1, nullptr, false, true);
auto* matmul_out_2 =
layers.matmul_v2(ln_out, weights_2, nullptr, false, true);
auto* b0 = layers.data("bias_0", {1024}, true); auto* b0 = layers.data("bias_0", {1024}, true);
auto* b1 = layers.data("bias_1", {1024}, true); auto* b1 = layers.data("bias_1", {1024}, true);
...@@ -136,14 +129,13 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -136,14 +129,13 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* transpose_1 = layers.transpose2(reshape_1, axis, true); auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true); auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
// Link to decoder while block // q scale
layers.while_loop({transpose_1, transpose_2}); auto* scale_q = layers.scale(transpose_0, 0.125, 0, false);
// MHA: QK matmul // MHA: QK matmul
auto* matmul_qk = auto* matmul_qk =
layers.matmul(transpose_0, transpose_1, nullptr, false, true); layers.matmul_v2(scale_q, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
...@@ -155,19 +147,18 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -155,19 +147,18 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// MHA: out Linear // MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true); auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("weightsl", {1024, 1024}, true); auto* bias_l = layers.data("bias_l", {1024}, true);
auto* linear_matmut_out = auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true); layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false);
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
// FFN: pre LayerNorm // post LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ffn_ln_out = auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0];
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2 // FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
...@@ -175,7 +166,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -175,7 +166,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out = auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true); layers.matmul_v2(ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out = auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
...@@ -184,7 +175,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -184,7 +175,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
layers.elementwise_add(attention_out, ffn_eltadd1_out); auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out);
// FFN: post LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0];
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -203,12 +199,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -203,12 +199,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before, PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 56, num_nodes_after + 58,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_pass, The " "After the fused_multi_transformer_encoder_pass, The "
"node num in graph " "node num in graph "
"should be %d, but the result is %d", "should be %d, but the result is %d",
num_nodes_before - 56, num_nodes_before - 58,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -225,6 +221,183 @@ TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) { ...@@ -225,6 +221,183 @@ TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) {
.IsPassCompatible("fused_multi_transformer_encoder_pass")); .IsPassCompatible("fused_multi_transformer_encoder_pass"));
} }
TEST(MultiDevicesFusedMultiTransformerEncoderPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x) c_identity -> c_identity0_out
// (x) c_identity -> c_identity1_out
// (x) c_identity -> c_identity2_out
// (c_identity0_out, weights_0) matmul_v2 -> matmul_out0
// (c_identity1_out, weights_1) matmul_v2 -> matmul_out1
// (c_identity2_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0) scale -> scale_0
// (scale_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_all_reduce -> c_all_reduce_out
// (c_all_reduce_out) elementwise_add -> eltadd_linear
// (eltadd_linear) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> ffn_c_identity_out
// (ffn_c_identity_out, ffn_matmul0_w)matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out
// (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1
// (layer_norm_out, ffn_eltadd1) elementwise_add -> ffn_output
// (ffn_output, scale, bias) layer_norm -> ffn_layer_norm_out
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* c_identity0_out = layers.c_identity(x);
auto* c_identity1_out = layers.c_identity(x);
auto* c_identity2_out = layers.c_identity(x);
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
auto* matmul_out_0 =
layers.matmul_v2(c_identity0_out, weights_0, nullptr, false, false);
auto* matmul_out_1 =
layers.matmul_v2(c_identity1_out, weights_1, nullptr, false, false);
auto* matmul_out_2 =
layers.matmul_v2(c_identity2_out, weights_2, nullptr, false, false);
auto* b0 = layers.data("bias_0", {1024}, true);
auto* b1 = layers.data("bias_1", {1024}, true);
auto* b2 = layers.data("bias_2", {1024}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 =
layers.elementwise_add(matmul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 =
layers.elementwise_add(matmul_out_2, b2, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
// q scale
auto* scale_q = layers.scale(transpose_0, 0.125, 0, false);
// MHA: QK matmul
auto* matmul_qk =
layers.matmul_v2(scale_q, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("bias_l", {1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false);
auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out);
auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
// post LayerNorm
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0];
auto* ffn_c_identity_out = layers.c_identity(ln_out);
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, false);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, false);
auto* ffn_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);
auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out);
// FFN: post LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0];
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass = PassRegistry::Instance().Get(
"multi_devices_fused_multi_transformer_encoder_pass");
if (pass.get() == nullptr)
LOG(INFO)
<< "get multi_devices_fused_multi_transformer_encoder_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 70,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_pass, The "
"node num in graph "
"should be %d, but the result is %d",
num_nodes_before - 70,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d",
num_fused_nodes_after));
}
TEST(MultiDevicesFusedMultiTransformerEncoderPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible(
"multi_devices_fused_multi_transformer_encoder_pass"));
}
TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// inputs operator output // inputs operator output
// -------------------------------------------------------------------- // --------------------------------------------------------------------
...@@ -292,7 +465,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -292,7 +465,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true); auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false); auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk); auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
...@@ -447,7 +620,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -447,7 +620,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true); auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false); auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk); auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
...@@ -542,4 +715,5 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, ...@@ -542,4 +715,5 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
USE_PASS(fused_multi_transformer_encoder_pass); USE_PASS(fused_multi_transformer_encoder_pass);
USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass); USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass);
USE_PASS(multi_devices_fused_multi_transformer_encoder_pass);
USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass); USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass);
...@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{ ...@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fused_multi_transformer_decoder_pass", "fused_multi_transformer_decoder_pass",
"fused_multi_transformer_encoder_fuse_qkv_pass", "fused_multi_transformer_encoder_fuse_qkv_pass",
"fused_multi_transformer_decoder_fuse_qkv_pass", "fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"fuse_multi_transformer_layer_pass", "fuse_multi_transformer_layer_pass",
...@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_decoder_pass", // "fused_multi_transformer_decoder_pass", //
"fused_multi_transformer_encoder_fuse_qkv_pass", // "fused_multi_transformer_encoder_fuse_qkv_pass", //
"fused_multi_transformer_decoder_fuse_qkv_pass", // "fused_multi_transformer_decoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_encoder_pass", //
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", // "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", // "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", //
"fuse_multi_transformer_layer_pass", // "fuse_multi_transformer_layer_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册