diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc index 4e2bca2ae2a9703704b7fa2d2ed565713cc4d551..b730d46ab7c5f9d481651e674f7491d18b2d582d 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -62,34 +62,7 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, fused_multi_transformer_name, "Out"); if (is_decoder) { - auto shape_repr = - PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i)); - node_reprs["shape_" + std::to_string(i)] = shape_repr; - auto* shape = pattern->NewNode(shape_repr)->assert_is_op("shape"); - - auto shape_out_repr = - PDNodeName(name_scope_, repr_, id_, "shape_out_" + std::to_string(i)); - node_reprs["shape_out_" + std::to_string(i)] = shape_out_repr; - auto* shape_out = - pattern->NewNode(shape_out_repr)->assert_is_op_output("shape", "Out"); - - shape->LinksFrom({src_mask}).LinksTo({shape_out}); - - auto slice_repr = - PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i)); - node_reprs["slice_" + std::to_string(i)] = slice_repr; - auto* slice = pattern->NewNode(slice_repr)->assert_is_op("slice"); - - auto slice_out_repr = - PDNodeName(name_scope_, repr_, id_, "slice_out_" + std::to_string(i)); - node_reprs["slice_out_" + std::to_string(i)] = slice_out_repr; - auto* slice_out = - pattern->NewNode(slice_out_repr)->assert_is_op_output("slice", "Out"); - - slice->LinksFrom({shape_out}).LinksTo({slice_out}); - - fused_multi_transformer->LinksFrom({x0, src_mask, slice_out}) - .LinksTo({out}); + fused_multi_transformer->LinksFrom({x0, src_mask}).LinksTo({out}); } else { auto cache_kv_repr = PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i)); @@ -187,10 +160,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, std::vector fuse_op_nodes; std::vector out_nodes; - std::vector unused_node_prefixes = { - "shape_", "shape_out_", "slice_", "slice_out_"}; - std::vector unused_nodes; - std::vector fuse_op_descs; std::vector fuse_op_input_var_name_maps; std::vector fuse_op_output_var_name_maps; @@ -219,14 +188,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, fill_op_node->Op()->SetInput("Input", {x0->Name()}); IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node); IR_NODE_LINK_TO(x0, fill_op_node); - } else if (is_decoder && i != 0) { - for (const auto& unused_node_prefix : unused_node_prefixes) { - PDNode* unused_pdnode = - multi_layer_pattern.PatternBase::pattern->RetrieveNode( - node_reprs[unused_node_prefix + std::to_string(i)]); - Node* unused_node = subgraph.at(unused_pdnode); - unused_nodes.push_back(unused_node); - } } } @@ -293,10 +254,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, std::unordered_set marked_fuse_op_nodes( fuse_op_nodes.begin() + 1, fuse_op_nodes.end()); - if (is_decoder) { - marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end()); - } - GraphSafeRemoveNodes(graph, marked_fuse_op_nodes); ++fusion_count; }; diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc index 42c699195beb91c6e7e29d0231e6a1aee1e75e98..2d93758f177d2831bb002c161262edd2532d88fe 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc @@ -1146,35 +1146,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, auto cache_kv_name = "cache_kv" + std::to_string(layer_idx); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name}); - VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx)); - shape_out_desc.SetDataType(proto::VarType::INT32); - shape_out_desc.SetPersistable(false); - auto* shape_out = graph->CreateVarNode(&shape_out_desc); - - OpDesc shape_op_desc(layer_norm->Op()->Block()); - shape_op_desc.SetType("shape"); - shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); - shape_op_desc.SetOutput("Out", {shape_out->Name()}); - auto* shape_op = graph->CreateOpNode(&shape_op_desc); - - VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx)); - slice_out_desc.SetDataType(proto::VarType::INT32); - slice_out_desc.SetPersistable(false); - auto* slice_out = graph->CreateVarNode(&slice_out_desc); - - OpDesc slice_op_desc(layer_norm->Op()->Block()); - slice_op_desc.SetType("slice"); - slice_op_desc.SetInput("Input", {shape_out->Name()}); - slice_op_desc.SetOutput("Out", {slice_out->Name()}); - std::vector axes = {0}; - std::vector starts = {3}; - std::vector ends = {4}; - slice_op_desc.SetAttr("axes", axes); - slice_op_desc.SetAttr("starts", starts); - slice_op_desc.SetAttr("ends", ends); - auto* slice_op = graph->CreateOpNode(&slice_op_desc); - - fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()}); + fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"}); // Out Linear input fused_multi_transformer_op_desc.SetInput("OutLinearW", @@ -1219,12 +1191,42 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); - // TimeStep link - IR_NODE_LINK_TO(eltadd_qk_b, shape_op); - IR_NODE_LINK_TO(shape_op, shape_out); - IR_NODE_LINK_TO(shape_out, slice_op); - IR_NODE_LINK_TO(slice_op, slice_out); - IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + if (layer_idx == 0) { + VarDesc shape_out_desc("shape_out.0"); + shape_out_desc.SetDataType(proto::VarType::INT32); + shape_out_desc.SetPersistable(false); + auto* shape_out = graph->CreateVarNode(&shape_out_desc); + + OpDesc shape_op_desc(layer_norm->Op()->Block()); + shape_op_desc.SetType("shape"); + shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); + shape_op_desc.SetOutput("Out", {shape_out->Name()}); + auto* shape_op = graph->CreateOpNode(&shape_op_desc); + + VarDesc slice_out_desc("slice_out.0"); + slice_out_desc.SetDataType(proto::VarType::INT32); + slice_out_desc.SetPersistable(false); + auto* slice_out = graph->CreateVarNode(&slice_out_desc); + + OpDesc slice_op_desc(layer_norm->Op()->Block()); + slice_op_desc.SetType("slice"); + slice_op_desc.SetInput("Input", {shape_out->Name()}); + slice_op_desc.SetOutput("Out", {slice_out->Name()}); + std::vector axes = {0}; + std::vector starts = {3}; + std::vector ends = {4}; + slice_op_desc.SetAttr("axes", axes); + slice_op_desc.SetAttr("starts", starts); + slice_op_desc.SetAttr("ends", ends); + auto* slice_op = graph->CreateOpNode(&slice_op_desc); + + // TimeStep link + IR_NODE_LINK_TO(eltadd_qk_b, shape_op); + IR_NODE_LINK_TO(shape_op, shape_out); + IR_NODE_LINK_TO(shape_out, slice_op); + IR_NODE_LINK_TO(slice_op, slice_out); + IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + } IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); @@ -1789,35 +1791,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( auto cache_kv_name = "cache_kv" + std::to_string(layer_idx); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name}); - VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx)); - shape_out_desc.SetDataType(proto::VarType::INT32); - shape_out_desc.SetPersistable(false); - auto* shape_out = graph->CreateVarNode(&shape_out_desc); - - OpDesc shape_op_desc(layer_norm->Op()->Block()); - shape_op_desc.SetType("shape"); - shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); - shape_op_desc.SetOutput("Out", {shape_out->Name()}); - auto* shape_op = graph->CreateOpNode(&shape_op_desc); - - VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx)); - slice_out_desc.SetDataType(proto::VarType::INT32); - slice_out_desc.SetPersistable(false); - auto* slice_out = graph->CreateVarNode(&slice_out_desc); - - OpDesc slice_op_desc(layer_norm->Op()->Block()); - slice_op_desc.SetType("slice"); - slice_op_desc.SetInput("Input", {shape_out->Name()}); - slice_op_desc.SetOutput("Out", {slice_out->Name()}); - std::vector axes = {0}; - std::vector starts = {3}; - std::vector ends = {4}; - slice_op_desc.SetAttr("axes", axes); - slice_op_desc.SetAttr("starts", starts); - slice_op_desc.SetAttr("ends", ends); - auto* slice_op = graph->CreateOpNode(&slice_op_desc); - - fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()}); + fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"}); // Out Linear input fused_multi_transformer_op_desc.SetInput("OutLinearW", @@ -1862,12 +1836,42 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); - // TimeStep link - IR_NODE_LINK_TO(eltadd_qk_b, shape_op); - IR_NODE_LINK_TO(shape_op, shape_out); - IR_NODE_LINK_TO(shape_out, slice_op); - IR_NODE_LINK_TO(slice_op, slice_out); - IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + if (layer_idx == 0) { + VarDesc shape_out_desc("shape_out.0"); + shape_out_desc.SetDataType(proto::VarType::INT32); + shape_out_desc.SetPersistable(false); + auto* shape_out = graph->CreateVarNode(&shape_out_desc); + + OpDesc shape_op_desc(layer_norm->Op()->Block()); + shape_op_desc.SetType("shape"); + shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); + shape_op_desc.SetOutput("Out", {shape_out->Name()}); + auto* shape_op = graph->CreateOpNode(&shape_op_desc); + + VarDesc slice_out_desc("slice_out.0"); + slice_out_desc.SetDataType(proto::VarType::INT32); + slice_out_desc.SetPersistable(false); + auto* slice_out = graph->CreateVarNode(&slice_out_desc); + + OpDesc slice_op_desc(layer_norm->Op()->Block()); + slice_op_desc.SetType("slice"); + slice_op_desc.SetInput("Input", {shape_out->Name()}); + slice_op_desc.SetOutput("Out", {slice_out->Name()}); + std::vector axes = {0}; + std::vector starts = {3}; + std::vector ends = {4}; + slice_op_desc.SetAttr("axes", axes); + slice_op_desc.SetAttr("starts", starts); + slice_op_desc.SetAttr("ends", ends); + auto* slice_op = graph->CreateOpNode(&slice_op_desc); + + // TimeStep link + IR_NODE_LINK_TO(eltadd_qk_b, shape_op); + IR_NODE_LINK_TO(shape_op, shape_out); + IR_NODE_LINK_TO(shape_out, slice_op); + IR_NODE_LINK_TO(slice_op, slice_out); + IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + } IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); @@ -2405,35 +2409,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( auto cache_kv_name = "cache_kv" + std::to_string(layer_idx); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name}); - VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx)); - shape_out_desc.SetDataType(proto::VarType::INT32); - shape_out_desc.SetPersistable(false); - auto* shape_out = graph->CreateVarNode(&shape_out_desc); - - OpDesc shape_op_desc(layer_norm->Op()->Block()); - shape_op_desc.SetType("shape"); - shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); - shape_op_desc.SetOutput("Out", {shape_out->Name()}); - auto* shape_op = graph->CreateOpNode(&shape_op_desc); - - VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx)); - slice_out_desc.SetDataType(proto::VarType::INT32); - slice_out_desc.SetPersistable(false); - auto* slice_out = graph->CreateVarNode(&slice_out_desc); - - OpDesc slice_op_desc(layer_norm->Op()->Block()); - slice_op_desc.SetType("slice"); - slice_op_desc.SetInput("Input", {shape_out->Name()}); - slice_op_desc.SetOutput("Out", {slice_out->Name()}); - std::vector axes = {0}; - std::vector starts = {3}; - std::vector ends = {4}; - slice_op_desc.SetAttr("axes", axes); - slice_op_desc.SetAttr("starts", starts); - slice_op_desc.SetAttr("ends", ends); - auto* slice_op = graph->CreateOpNode(&slice_op_desc); - - fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()}); + fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"}); // Out Linear input fused_multi_transformer_op_desc.SetInput("OutLinearW", @@ -2483,12 +2459,42 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); - // TimeStep link - IR_NODE_LINK_TO(eltadd_qk_b, shape_op); - IR_NODE_LINK_TO(shape_op, shape_out); - IR_NODE_LINK_TO(shape_out, slice_op); - IR_NODE_LINK_TO(slice_op, slice_out); - IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + if (layer_idx == 0) { + VarDesc shape_out_desc("shape_out.0"); + shape_out_desc.SetDataType(proto::VarType::INT32); + shape_out_desc.SetPersistable(false); + auto* shape_out = graph->CreateVarNode(&shape_out_desc); + + OpDesc shape_op_desc(layer_norm->Op()->Block()); + shape_op_desc.SetType("shape"); + shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); + shape_op_desc.SetOutput("Out", {shape_out->Name()}); + auto* shape_op = graph->CreateOpNode(&shape_op_desc); + + VarDesc slice_out_desc("slice_out.0"); + slice_out_desc.SetDataType(proto::VarType::INT32); + slice_out_desc.SetPersistable(false); + auto* slice_out = graph->CreateVarNode(&slice_out_desc); + + OpDesc slice_op_desc(layer_norm->Op()->Block()); + slice_op_desc.SetType("slice"); + slice_op_desc.SetInput("Input", {shape_out->Name()}); + slice_op_desc.SetOutput("Out", {slice_out->Name()}); + std::vector axes = {0}; + std::vector starts = {3}; + std::vector ends = {4}; + slice_op_desc.SetAttr("axes", axes); + slice_op_desc.SetAttr("starts", starts); + slice_op_desc.SetAttr("ends", ends); + auto* slice_op = graph->CreateOpNode(&slice_op_desc); + + // TimeStep link + IR_NODE_LINK_TO(eltadd_qk_b, shape_op); + IR_NODE_LINK_TO(shape_op, shape_out); + IR_NODE_LINK_TO(shape_out, slice_op); + IR_NODE_LINK_TO(slice_op, slice_out); + IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + } IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index a1980a8ba5005ecd25ab186e128b517c982834f4..19fd7279b9677ae5738490f6f22346d996aeb078 100755 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -177,6 +177,7 @@ const std::vector kGpuLowerPrecisionPasses{ "fused_multi_transformer_decoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", + "fuse_multi_transformer_layer_pass", "gpu_cpu_map_matmul_v2_to_mul_pass", "gpu_cpu_map_matmul_v2_to_matmul_pass", "fc_fuse_pass",