From 1e3245a8be96c97b86f54abb86554c9fc5c824d5 Mon Sep 17 00:00:00 2001 From: RichardWooSJTU <37864677+RichardWooSJTU@users.noreply.github.com> Date: Thu, 10 Nov 2022 11:12:17 +0800 Subject: [PATCH] Fuse multi transformer layer pass (#47541) * add fuse_multi_transformer_layer_pass --- paddle/fluid/framework/ir/CMakeLists.txt | 5 + .../ir/fuse_multi_transformer_layer_pass.cc | 325 ++++++++++++++++++ .../ir/fuse_multi_transformer_layer_pass.h | 60 ++++ ...use_multi_transformer_layer_pass_tester.cc | 175 ++++++++++ .../fused_multi_transformer_decoder_pass.cc | 3 + .../fused_multi_transformer_encoder_pass.cc | 3 + paddle/fluid/framework/ir/pass.cc | 2 +- paddle/fluid/framework/ir/pass.h | 4 + .../fluid/framework/ir/pass_tester_helper.h | 113 ++++++ .../inference/api/paddle_pass_builder.cc | 1 + 10 files changed, 690 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc create mode 100644 paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h create mode 100644 paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 7108ed158a..5e46fd92bf 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -107,6 +107,7 @@ pass_library(skip_layernorm_fuse_pass base) pass_library(multihead_matmul_fuse_pass inference) pass_library(fused_multi_transformer_encoder_pass inference) pass_library(fused_multi_transformer_decoder_pass inference) +pass_library(fuse_multi_transformer_layer_pass inference) pass_library(adaptive_pool2d_convert_global_pass inference) pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(yolo_box_fuse_pass inference) @@ -326,6 +327,10 @@ cc_test( test_fused_multi_transformer_decoder_pass SRCS fused_multi_transformer_decoder_pass_tester.cc DEPS fused_multi_transformer_decoder_pass) +cc_test( + test_fuse_multi_transformer_layer_pass + SRCS fuse_multi_transformer_layer_pass_tester.cc + DEPS fuse_multi_transformer_layer_pass) cc_test( test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc new file mode 100644 index 0000000000..4e2bca2ae2 --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -0,0 +1,325 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h" + +#include + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +std::unordered_map +MultiTransformerLayerPattern::operator()(bool enable_int8, + int num_fused_op, + bool is_decoder) { + std::string fused_multi_transformer_name = + enable_int8 ? "fused_multi_transformer_int8" : "fused_multi_transformer"; + + std::unordered_map node_reprs; + + // x0 and src_mask is unqiue input of subgraph + auto* x0 = pattern->NewNode(x0_repr()); + x0->assert_is_op_input(fused_multi_transformer_name, "X")->AsInput(); + auto* src_mask = pattern->NewNode(src_mask_repr()); + src_mask->assert_is_op_input(fused_multi_transformer_name, "SrcMask") + ->AsInput(); + + for (int i = 0; i < num_fused_op; ++i) { + auto fuse_op_repr = + PDNodeName(name_scope_, repr_, id_, "fuse_op_" + std::to_string(i)); + node_reprs["fuse_op_" + std::to_string(i)] = fuse_op_repr; + auto* fused_multi_transformer = + pattern->NewNode(fuse_op_repr) + ->assert_is_op(fused_multi_transformer_name); + + auto out_repr = + PDNodeName(name_scope_, repr_, id_, "out_" + std::to_string(i)); + node_reprs["out_" + std::to_string(i)] = out_repr; + auto* out = pattern->NewNode(out_repr)->assert_is_op_output( + fused_multi_transformer_name, "Out"); + + if (is_decoder) { + auto shape_repr = + PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i)); + node_reprs["shape_" + std::to_string(i)] = shape_repr; + auto* shape = pattern->NewNode(shape_repr)->assert_is_op("shape"); + + auto shape_out_repr = + PDNodeName(name_scope_, repr_, id_, "shape_out_" + std::to_string(i)); + node_reprs["shape_out_" + std::to_string(i)] = shape_out_repr; + auto* shape_out = + pattern->NewNode(shape_out_repr)->assert_is_op_output("shape", "Out"); + + shape->LinksFrom({src_mask}).LinksTo({shape_out}); + + auto slice_repr = + PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i)); + node_reprs["slice_" + std::to_string(i)] = slice_repr; + auto* slice = pattern->NewNode(slice_repr)->assert_is_op("slice"); + + auto slice_out_repr = + PDNodeName(name_scope_, repr_, id_, "slice_out_" + std::to_string(i)); + node_reprs["slice_out_" + std::to_string(i)] = slice_out_repr; + auto* slice_out = + pattern->NewNode(slice_out_repr)->assert_is_op_output("slice", "Out"); + + slice->LinksFrom({shape_out}).LinksTo({slice_out}); + + fused_multi_transformer->LinksFrom({x0, src_mask, slice_out}) + .LinksTo({out}); + } else { + auto cache_kv_repr = + PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i)); + node_reprs["cache_kv_" + std::to_string(i)] = cache_kv_repr; + auto* cache_kv = pattern->NewNode(cache_kv_repr); + cache_kv->assert_is_op_input(fused_multi_transformer_name, "CacheKV"); + cache_kv->AsInput(); + + auto fill_const_op_repr = + PDNodeName(name_scope_, repr_, id_, "fill_op_" + std::to_string(i)); + node_reprs["fill_op_" + std::to_string(i)] = fill_const_op_repr; + auto fill_const_op = pattern->NewNode(fill_const_op_repr) + ->assert_is_op("fill_constant_batch_size_like"); + + fused_multi_transformer->LinksFrom({x0, src_mask, cache_kv}) + .LinksTo({out}); + fill_const_op->LinksFrom({x0}).LinksTo({cache_kv}); + } + x0 = out; + } + x0->AsOutput(); + return node_reprs; +} +} // namespace patterns + +inline void MergeInput(OpDesc* op, + const std::vector& input_name_maps, + const std::string& input_name) { + std::vector tmp = input_name_maps[0].at(input_name); + for (size_t i = 1; i < input_name_maps.size(); ++i) { + tmp.insert(tmp.end(), + input_name_maps[i].at(input_name).begin(), + input_name_maps[i].at(input_name).end()); + } + op->SetInput(input_name, tmp); +} + +template +inline void MergeAttrs(const std::vector& ops, + const std::string& attr_name) { + std::vector res; + for (size_t i = 0; i < ops.size(); ++i) { + auto scale_vec = + PADDLE_GET_CONST(std::vector, ops[i]->GetAttr(attr_name)); + res.insert(res.end(), scale_vec.begin(), scale_vec.end()); + } + ops[0]->SetAttr(attr_name, res); +} + +int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // TODO(wufeisheng): Get enable_int8 attr from graph after + // fused_multi_transformer pass with int8 merged + bool enable_int8 = false; + + int num_fuse_op = 0; + bool is_decoder = false; + + if (graph->Has(kFusedMultiTransformerEncoderFusionCount)) { + num_fuse_op = graph->Get(kFusedMultiTransformerEncoderFusionCount); + is_decoder = false; + } else if (graph->Has(kFusedMultiTransformerDecoderFusionCount)) { + num_fuse_op = graph->Get(kFusedMultiTransformerDecoderFusionCount); + is_decoder = true; + } + if (num_fuse_op == 0) { + VLOG(4) << "fuse_multi_transformer_layer_pass will be skipped " + "cause num_fuse_op is not been set or set to 0"; + return 0; + } + if (!is_decoder) { + VLOG(4) << "fuse_multi_transformer_layer_pass will match encoder pattern"; + } else { + VLOG(4) << "fuse_multi_transformer_layer_pass will match decoder pattern"; + } + + patterns::MultiTransformerLayerPattern multi_layer_pattern(pattern, + name_scope); + auto node_reprs = multi_layer_pattern(enable_int8, num_fuse_op, is_decoder); + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + /////////////////// + //// Get nodes //// + /////////////////// + + GET_IR_NODE_FROM_SUBGRAPH(src_mask, src_mask, multi_layer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(x0, x0, multi_layer_pattern); + + std::vector fuse_op_nodes; + std::vector out_nodes; + + std::vector unused_node_prefixes = { + "shape_", "shape_out_", "slice_", "slice_out_"}; + std::vector unused_nodes; + + std::vector fuse_op_descs; + std::vector fuse_op_input_var_name_maps; + std::vector fuse_op_output_var_name_maps; + + for (int i = 0; i < num_fuse_op; ++i) { + PDNode* fuse_op_pdnode = + multi_layer_pattern.PatternBase::pattern->RetrieveNode( + node_reprs["fuse_op_" + std::to_string(i)]); + Node* fuse_op_node = subgraph.at(fuse_op_pdnode); + fuse_op_nodes.push_back(fuse_op_node); + fuse_op_descs.push_back(fuse_op_node->Op()); + fuse_op_input_var_name_maps.emplace_back(fuse_op_node->Op()->Inputs()); + fuse_op_output_var_name_maps.emplace_back(fuse_op_node->Op()->Outputs()); + + PDNode* out_pdnode = + multi_layer_pattern.PatternBase::pattern->RetrieveNode( + node_reprs["out_" + std::to_string(i)]); + out_nodes.push_back(subgraph.at(out_pdnode)); + + // fill_const op use x0 as input + if (!is_decoder && i != 0) { + PDNode* fill_op_pdnode = + multi_layer_pattern.PatternBase::pattern->RetrieveNode( + node_reprs["fill_op_" + std::to_string(i)]); + Node* fill_op_node = subgraph.at(fill_op_pdnode); + fill_op_node->Op()->SetInput("Input", {x0->Name()}); + IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node); + IR_NODE_LINK_TO(x0, fill_op_node); + } else if (is_decoder && i != 0) { + for (const auto& unused_node_prefix : unused_node_prefixes) { + PDNode* unused_pdnode = + multi_layer_pattern.PatternBase::pattern->RetrieveNode( + node_reprs[unused_node_prefix + std::to_string(i)]); + Node* unused_node = subgraph.at(unused_pdnode); + unused_nodes.push_back(unused_node); + } + } + } + + /////////////// + //// Merge //// + /////////////// + + // Merge inputs + std::vector inputs_names = {"CacheKV", + "FFN1Bias", + "FFN1Weight", + "FFN2Bias", + "FFN2Weight", + "FFNLnBias", + "FFNLnScale", + "LnBias", + "LnScale", + "OutLinearBias", + "OutLinearW", + "QKVBias", + "QKVW"}; + + for (const auto& input_name : inputs_names) { + MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name); + } + + // Merge outputs + fuse_op_descs[0]->SetOutput( + "Out", fuse_op_output_var_name_maps[num_fuse_op - 1]["Out"]); + auto& merged_cache_kv_out_names = + fuse_op_output_var_name_maps[0]["CacheKVOut"]; + for (int i = 1; i < num_fuse_op; ++i) { + const auto& out_var_names = fuse_op_output_var_name_maps[i]["CacheKVOut"]; + merged_cache_kv_out_names.insert(merged_cache_kv_out_names.end(), + out_var_names.begin(), + out_var_names.end()); + } + fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names); + + //////////////// + //// ReLink //// + //////////////// + // Before relink, out nodes (0 -> num_layer-1) should be removed + std::unordered_set marked_out_nodes(out_nodes.begin(), + out_nodes.end() - 1); + GraphSafeRemoveNodes(graph, marked_out_nodes); + + // Relink all input nodes of fused_multi_transformer ops to the first op + auto& merged_inputs = fuse_op_nodes[0]->inputs; + for (int i = 1; i < num_fuse_op; ++i) { + merged_inputs.insert(merged_inputs.end(), + fuse_op_nodes[i]->inputs.begin(), + fuse_op_nodes[i]->inputs.end()); + } + + // Relink fuse op -> out + IR_NODE_UNLINK(fuse_op_nodes[num_fuse_op - 1], out_nodes[num_fuse_op - 1]); + IR_NODE_LINK_TO(fuse_op_nodes[0], out_nodes[num_fuse_op - 1]); + + ///////////////////////////// + //// Delete unused nodes //// + ///////////////////////////// + // Delete fused_multi_transformer op expect for the first one + std::unordered_set marked_fuse_op_nodes( + fuse_op_nodes.begin() + 1, fuse_op_nodes.end()); + + if (is_decoder) { + marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end()); + } + + GraphSafeRemoveNodes(graph, marked_fuse_op_nodes); + ++fusion_count; + }; + + gpd(graph, handler); + return fusion_count; +} + +void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal("During the fuse_multi_transformer_layer pass, " + "The scope should not be null.")); + int fusion_count = BuildFusion(graph, name_scope_, scope); + + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fuse_multi_transformer_layer_pass, + paddle::framework::ir::FuseMultiTransformerLayerPass); diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h new file mode 100644 index 0000000000..339cc6815e --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct MultiTransformerLayerPattern : public PatternBase { + MultiTransformerLayerPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "fuse_multi_transformer_layer") {} + + std::unordered_map operator()( + bool enable_int8, int num_fused_op = 1, bool is_decoder = false); + + PATTERN_DECL_NODE(src_mask); + PATTERN_DECL_NODE(x0); +}; + +} // namespace patterns + +class FuseMultiTransformerLayerPass : public FusePassBase { + public: + FuseMultiTransformerLayerPass() {} + virtual ~FuseMultiTransformerLayerPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"fuse_multi_transformer_layer"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc new file mode 100644 index 0000000000..72635d1c95 --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc @@ -0,0 +1,175 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +#define DEF_INPUT_DATA \ + Layers layers; \ + int num_layers = 3; \ + auto* x = layers.data("x", {1, 128, 1024}); \ + auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \ + auto* ln_scale = layers.data("ln_scale", {1024}, true); \ + auto* ln_bias = layers.data("ln_bias", {1024}, true); \ + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \ + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \ + auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \ + auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \ + auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \ + auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \ + auto* qkv_bias = layers.data("qkv_bias", {3072}, true); \ + auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \ + auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \ + auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true); + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, + const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "ln_scale", {1024}); + AddVarToScope(param_scope, "ln_bias", {1024}); + AddVarToScope(param_scope, "ffn_ln_scale", {1024}); + AddVarToScope(param_scope, "ffn_ln_bias", {1024}); + + AddVarToScope(param_scope, "qkv_w", {3, 16, 64, 1024}); + AddVarToScope(param_scope, "out_linear_w", {1024, 1024}); + AddVarToScope(param_scope, "ffn1_w", {1024, 4096}); + AddVarToScope(param_scope, "ffn2_w", {4096, 1024}); + AddVarToScope(param_scope, "qkv_bias", {3072}); + AddVarToScope(param_scope, "out_linear_bias", {1024}); + AddVarToScope(param_scope, "ffn1_bias", {4096}); + AddVarToScope(param_scope, "ffn2_bias", {1024}); + + return param_scope; +} +TEST(FuseMultiTransformerLayerPass, encoder_fp) { + DEF_INPUT_DATA + + // Layers + for (int i = 0; i < num_layers; ++i) { + auto* cache_kv = layers.fill_constant_batch_size_like( + x, + static_cast(proto::VarType::FP32), + 0, + 1, + {2, -1, 16, 1024, 64}, + 0); + auto* out = layers.fused_multi_transformer(x, + cache_kv, + src_mask, + qkv_w, + qkv_bias, + out_linear_w, + out_linear_bias, + ffn1_w, + ffn1_bias, + ffn2_w, + ffn2_bias, + ln_scale, + ln_bias, + ffn_ln_scale, + ffn_ln_bias, + 0.1, + 1e-12); + + x = out; + } + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(num_layers)); + + auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass"); + if (pass.get() == nullptr) + LOG(INFO) << "get fuse_multi_transformer_layer_pass failed"; + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ( + num_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fuse_multi_transformer_layer_pass, " + "The node num in graph should be 1, but the result is %d", + num_nodes_after)); +} +TEST(FuseMultiTransformerLayerPass, decoder_fp) { + DEF_INPUT_DATA + + x = layers.data("x", {1, 1, 1024}); + auto* cache_kv = layers.data("cache_kv", {2, 1, 16, 1024, 64}, true); + src_mask = layers.data("src_mask", {1, 16, 1, 129}); + + // Layers + for (int i = 0; i < num_layers; ++i) { + auto* shape_out = layers.shape(src_mask); + auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4}); + auto* out = layers.fused_multi_transformer(x, + cache_kv, + src_mask, + qkv_w, + qkv_bias, + out_linear_w, + out_linear_bias, + ffn1_w, + ffn1_bias, + ffn2_w, + ffn2_bias, + ln_scale, + ln_bias, + ffn_ln_scale, + ffn_ln_bias, + 0.1, + 1e-12, + time_stamp); + + x = out; + } + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto param_scope = CreateParamScope(); + AddVarToScope(param_scope, "cache_kv", {2, 1, 16, 1024, 64}); + graph->Set("__param_scope__", param_scope); + + graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(num_layers)); + + auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass"); + if (pass.get() == nullptr) + LOG(INFO) << "get fuse_multi_transformer_layer_pass failed"; + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ( + num_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fuse_multi_transformer_layer_pass, " + "The node num in graph should be 1, but the result is %d", + num_nodes_after)); +} +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fuse_multi_transformer_layer_pass); diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc index ef896e9c7e..42c699195b 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc @@ -1565,6 +1565,7 @@ void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerDecoderPass, new bool(true)); + graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } @@ -2178,6 +2179,7 @@ void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } @@ -2833,6 +2835,7 @@ void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl( int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc index 8738779f5e..0503b3a0a3 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc @@ -1728,6 +1728,7 @@ void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerEncoderPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } @@ -2380,6 +2381,7 @@ void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } @@ -3076,6 +3078,7 @@ void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl( if (fusion_count > 0) { graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 74ad71f37d..4ad9318399 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -46,7 +46,7 @@ static const std::vector support_subgraph_passes = { "fused_multi_transformer_decoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", -}; + "fuse_multi_transformer_layer_pass"}; Graph *Pass::Apply(Graph *graph) const { VLOG(10) << "start to apply pass " << Type() << " to graph"; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 2ed753cdeb..e0315f0b5b 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -59,6 +59,10 @@ constexpr char kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass[] = "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass_flag"; constexpr char kMultiDevicesFusedMultiTransformerDecoderFuseQKVPass[] = "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass_flag"; +constexpr char kFusedMultiTransformerEncoderFusionCount[] = + "fused_multi_transformer_encoder_fusion_count"; +constexpr char kFusedMultiTransformerDecoderFusionCount[] = + "fused_multi_transformer_decoder_fusion_count"; constexpr char kPrelnEmbEltwiseLayernormPass[] = "preln_embedding_eltwise_layernorm_fuse_pass_flag"; diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 3cce19e10c..48f8cb37d6 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -528,6 +528,119 @@ struct Layers { return out; } + VarDesc* shape(VarDesc* input) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("shape"); + op->SetInput("Input", {input->Name()}); + op->SetOutput("Out", {out->Name()}); + return out; + } + + VarDesc* slice(VarDesc* input, + std::vector axes, + std::vector starts, + std::vector ends) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("slice"); + op->SetInput("Input", {input->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetAttr("axes", axes); + op->SetAttr("starts", starts); + op->SetAttr("ends", ends); + return out; + } + + VarDesc* fill_constant_batch_size_like(VarDesc* x, + int dtype, + int input_dim_idx, + int output_dim_idx, + std::vector shape, + float value) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("fill_constant_batch_size_like"); + op->SetInput("Input", {x->Name()}); + op->SetAttr("dtype", dtype); + op->SetAttr("input_dim_idx", input_dim_idx); + op->SetAttr("output_dim_idx", output_dim_idx); + op->SetAttr("shape", shape); + op->SetAttr("value", value); + op->SetOutput("Out", {out->Name()}); + return out; + } + + VarDesc* fused_multi_transformer(VarDesc* x, + VarDesc* cache_kv, + VarDesc* src_mask, + VarDesc* qkv_w, + VarDesc* qkv_bias, + VarDesc* out_linear_w, + VarDesc* out_linear_bias, + VarDesc* ffn1_w, + VarDesc* ffn1_bias, + VarDesc* ffn2_w, + VarDesc* ffn2_bias, + VarDesc* ln_scale, + VarDesc* ln_bias, + VarDesc* ffn_ln_scale, + VarDesc* ffn_ln_bias, + float epsilon, + float dropout_rate, + VarDesc* time_stamp = nullptr, + VarDesc* qkv_out_scale = nullptr, + VarDesc* out_linear_out_scale = nullptr, + VarDesc* ffn1_out_scale = nullptr, + VarDesc* ffn2_out_scale = nullptr, + std::vector qkv_in_scale = {}, + std::vector out_linear_in_scale = {}, + std::vector ffn1_in_scale = {}, + std::vector ffn2_in_scale = {}) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + std::string op_type = qkv_out_scale ? "fused_multi_transformer_int8" + : "fused_multi_transformer"; + op->SetType(op_type); + op->SetInput("X", {x->Name()}); + op->SetInput("CacheKV", {cache_kv->Name()}); + op->SetInput("SrcMask", {src_mask->Name()}); + op->SetInput("QKVW", {qkv_w->Name()}); + op->SetInput("QKVBias", {qkv_bias->Name()}); + op->SetInput("OutLinearW", {out_linear_w->Name()}); + op->SetInput("OutLinearBias", {out_linear_bias->Name()}); + op->SetInput("FFN1Weight", {ffn1_w->Name()}); + op->SetInput("FFN1Bias", {ffn1_bias->Name()}); + op->SetInput("FFN2Weight", {ffn2_w->Name()}); + op->SetInput("FFN2Bias", {ffn2_bias->Name()}); + op->SetInput("LnScale", {ln_scale->Name()}); + op->SetInput("LnBias", {ln_bias->Name()}); + op->SetInput("FFNLnScale", {ffn_ln_scale->Name()}); + op->SetInput("FFNLnBias", {ffn_ln_bias->Name()}); + op->SetAttr("pre_layer_norm", true); + op->SetAttr("is_test", true); + op->SetAttr("dropout_implementation", "upscale_in_train"); + op->SetAttr("dropout_rate", dropout_rate); + op->SetAttr("epsilon", epsilon); + op->SetOutput("Out", {out->Name()}); + + if (time_stamp) { + op->SetInput("TimeStep", {time_stamp->Name()}); + } + + if (qkv_out_scale) { + op->SetInput("QKVOutScale", {qkv_out_scale->Name()}); + op->SetInput("OutLinearOutScale", {out_linear_out_scale->Name()}); + op->SetInput("FFN1OutScale", {ffn1_out_scale->Name()}); + op->SetInput("FFN2OutScale", {ffn2_out_scale->Name()}); + op->SetAttr("qkv_in_scale", qkv_in_scale); + op->SetAttr("out_linear_in_scale", out_linear_in_scale); + op->SetAttr("ffn1_in_scale", ffn1_in_scale); + op->SetAttr("ffn2_in_scale", ffn2_in_scale); + } + return out; + } + void backward(std::vector targets) { // This function is designed to simulate the structure of training program, // but is constructed differently as the actual program. diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 21fd3bc0cf..aa699b4458 100755 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -212,6 +212,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "fused_multi_transformer_decoder_fuse_qkv_pass", // "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", // "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", // + "fuse_multi_transformer_layer_pass", // "gpu_cpu_squeeze2_matmul_fuse_pass", // "gpu_cpu_reshape2_matmul_fuse_pass", // "gpu_cpu_flatten2_matmul_fuse_pass", // -- GitLab