diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 9fa7527827ae4ab80239a939066a9f87eee8658c..944dbd07bab5c489b119697fb322bd263b7cc365 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -245,6 +245,8 @@ if(WITH_XPU)
   pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(fused_multi_transformer_cachekv_layout_trans_pass inference DIR
+               xpu DEPS ${XPU_PASS_DEPS})
 endif()
 
 cc_library(
@@ -528,4 +530,8 @@ if(WITH_XPU)
     test_stack_fuse_pass
     SRCS xpu/stack_fuse_pass_test.cc
     DEPS stack_fuse_pass)
+  cc_test(
+    test_fused_multi_transformer_cachekv_layout_trans_pass
+    SRCS xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc
+    DEPS fused_multi_transformer_cachekv_layout_trans_pass)
 endif()
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index cc4033f7f5a54585546092f870fbf49f38c424bc..6611be59fcc800b1fa643f00732a651fa0815e66 100644
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -62,6 +62,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
     "embedding_with_eltwise_add_xpu_fuse_pass",
     "multi_encoder_xpu_fuse_pass",
     "multi_encoder_xpu_slice_fuse_pass",
+    "fused_multi_transformer_cachekv_layout_trans_pass",
     "one_beam_size_fuse_pass",
     "stack_fuse_pass",
     "fused_multi_transformer_xpu_quant_pass",
diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..993b5a055280d869dea26e9a27b4b5c860717495
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc
@@ -0,0 +1,202 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
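+//
+// This pass transposes the CacheKV inputs of fused_multi_transformer from
+// [2, batch_size, num_head, max_seq_len, head_dim] to
+// [2, max_seq_len, batch_size, num_head, head_dim]: it permutes the
+// ShapeTensorList of the fill_constant that creates the cache, updates the
+// axis of the gather that reorders it, and marks the rewritten ops with
+// friendly_device_type = "xpu".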
+
+#include "paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h"
+#include <string>
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct FusedMultiTransformerFillConstantPattern : public PatternBase {
+  FusedMultiTransformerFillConstantPattern(PDPattern* pattern,
+                                           const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fill_constant);
+  PATTERN_DECL_NODE(fused_multi_transformer);
+  // declare variable node's name
+  PATTERN_DECL_NODE(fill_constant_out);
+};  // struct FusedMultiTransformerFillConstantPattern
+
+FusedMultiTransformerFillConstantPattern::
+    FusedMultiTransformerFillConstantPattern(PDPattern* pattern,
+                                             const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* fill_constant = pattern->NewNode(fill_constant_repr())
+                            ->assert_is_op("fill_constant")
+                            ->assert_has_n_inputs(5)
+                            ->assert_more([](Node* node) {
+                              return node->Op()->GetAttrIfExists<std::string>(
+                                         "friendly_device_type") != "xpu";
+                            });
+  auto* fill_constant_out = pattern->NewNode(fill_constant_out_repr())
+                                ->assert_is_op_output("fill_constant", "Out");
+  auto* fused_multi_transformer =
+      pattern->NewNode(fused_multi_transformer_repr())
+          ->assert_is_op("fused_multi_transformer");
+
+  fill_constant->LinksTo({fill_constant_out});
+  fused_multi_transformer->LinksFrom({fill_constant_out});
+}
+
+struct FusedMultiTransformerGatherPattern : public PatternBase {
+  FusedMultiTransformerGatherPattern(PDPattern* pattern,
+                                     const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fused_multi_transformer);
+  PATTERN_DECL_NODE(gather);
+  // declare variable node's name
+  PATTERN_DECL_NODE(gather_in);
+  PATTERN_DECL_NODE(gather_out);
+};  // struct FusedMultiTransformerGatherPattern
+
+FusedMultiTransformerGatherPattern::FusedMultiTransformerGatherPattern(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* gather_in =
+      pattern->NewNode(gather_in_repr())->assert_is_op_input("gather", "X");
+  auto* gather = pattern->NewNode(gather_repr())
+                     ->assert_is_op("gather")
+                     ->assert_more([](Node* node) {
+                       return node->Op()->GetAttrIfExists<int>("axis") == 1;
+                     });
+  auto* gather_out =
+      pattern->NewNode(gather_out_repr())->assert_is_op_output("gather", "Out");
+  auto* fused_multi_transformer =
+      pattern->NewNode(fused_multi_transformer_repr())
+          ->assert_is_op("fused_multi_transformer");
+
+  gather->LinksFrom({gather_in}).LinksTo({gather_out});
+  fused_multi_transformer->LinksFrom({gather_out});
+}
+}  // namespace patterns
+
+void FusedMultiTransformerCacheKVLayoutTransPass::FillConstantReshapePass(
+    ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  GraphPatternDetector gpd;
+  patterns::FusedMultiTransformerFillConstantPattern pattern(
+      gpd.mutable_pattern(), name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle FillConstantReshapePass";
+    GET_IR_NODE(fused_multi_transformer);
+    GET_IR_NODE(fill_constant);
+    GET_IR_NODE(fill_constant_out);
+    auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV");
+    if (std::count(cachekv_names.begin(),
+                   cachekv_names.end(),
+                   fill_constant_out->Name()) == 0)
+      return;
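+
+    // Reorder the ShapeTensorList inputs (and the output shape below) from
+    // [d0, d1, d2, d3, d4] to [d0, d3, d1, d2, d4], so that fill_constant
+    // creates the CacheKV tensor directly in the transposed layout.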
+    auto fill_constant_input_names =
+        fill_constant->Op()->Input("ShapeTensorList");
+    auto fill_constant_trans_input_names =
+        std::vector<std::string>{fill_constant_input_names[0],
+                                 fill_constant_input_names[3],
+                                 fill_constant_input_names[1],
+                                 fill_constant_input_names[2],
+                                 fill_constant_input_names[4]};
+    fill_constant->Op()->SetInput("ShapeTensorList",
+                                  fill_constant_trans_input_names);
+
+    auto fill_constant_output_shape = fill_constant_out->Var()->GetShape();
+    fill_constant_out->Var()->SetShape({fill_constant_output_shape[0],
+                                        fill_constant_output_shape[3],
+                                        fill_constant_output_shape[1],
+                                        fill_constant_output_shape[2],
+                                        fill_constant_output_shape[4]});
+
+    fused_multi_transformer->Op()->SetAttr("friendly_device_type",
+                                           std::string("xpu"));
+    fill_constant->Op()->SetAttr("friendly_device_type", std::string("xpu"));
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+void FusedMultiTransformerCacheKVLayoutTransPass::GatherReshapePass(
+    ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  GraphPatternDetector gpd;
+  patterns::FusedMultiTransformerGatherPattern pattern(gpd.mutable_pattern(),
+                                                       name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle GatherReshapePass";
+    GET_IR_NODE(gather);
+    GET_IR_NODE(fused_multi_transformer);
+    GET_IR_NODE(gather_in);
+    GET_IR_NODE(gather_out);
+    auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV");
+    if (std::count(cachekv_names.begin(),
+                   cachekv_names.end(),
+                   gather_out->Name()) == 0)
+      return;
+
+    auto gather_in_shape = gather_in->Var()->GetShape();
+    auto gather_out_shape = gather_out->Var()->GetShape();
+    gather_in->Var()->SetShape({gather_in_shape[0],
+                                gather_in_shape[3],
+                                gather_in_shape[1],
+                                gather_in_shape[2],
+                                gather_in_shape[4]});
+    gather_out->Var()->SetShape({gather_out_shape[0],
+                                 gather_out_shape[3],
+                                 gather_out_shape[1],
+                                 gather_out_shape[2],
+                                 gather_out_shape[4]});
+    gather->Op()->SetAttr("axis", 2);
+    fused_multi_transformer->Op()->SetAttr("friendly_device_type",
+                                           std::string("xpu"));
+
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+void FusedMultiTransformerCacheKVLayoutTransPass::ApplyImpl(
+    ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+
+  FillConstantReshapePass(graph);
+  GatherReshapePass(graph);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(
+    fused_multi_transformer_cachekv_layout_trans_pass,
+    paddle::framework::ir::FusedMultiTransformerCacheKVLayoutTransPass);
diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..cb87317a76e6a040c97738f36f4707b1a1191b43
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FusedMultiTransformerCacheKVLayoutTransPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  /*
+  Origin subgraph:
+      (ShapeTensorList: [d0,d1,d2,d3,d4])
+                     |
+               fill_constant
+                     |
+         fused_multi_transformer
+
+  Fused subgraph:
+      (ShapeTensorList: [d0,d3,d1,d2,d4])
+                     |
+               fill_constant
+                     |
+         fused_multi_transformer
+  */
+  void FillConstantReshapePass(ir::Graph* graph) const;
+
+  /*
+  Origin subgraph:
+      (gather_x: [d0,d1,d2,d3,d4])
+                     |
+               gather(axis=1)
+                     |
+         fused_multi_transformer
+
+  Fused subgraph:
+      (gather_x: [d0,d3,d1,d2,d4])
+                     |
+               gather(axis=2)
+                     |
+         fused_multi_transformer
+  */
+  void GatherReshapePass(ir::Graph* graph) const;
+
+  const std::string name_scope_{
+      "fused_multi_transformer_cachekv_layout_trans_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
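Note: both sub-passes documented above apply the same index permutation, {0, 3, 1, 2, 4}, to the 5-D CacheKV shape, and the batch dimension that gather selects on moves from axis 1 to axis 2. A minimal standalone sketch of that bookkeeping (illustrative only, not part of the patch; the shape values mirror the unit test below):

    #include <array>
    #include <cassert>
    #include <cstdint>

    int main() {
      // Original fused_multi_transformer CacheKV layout:
      // [2, batch_size, num_head, max_seq_len, head_dim]
      std::array<int64_t, 5> cachekv_shape{2, 1, 24, 512, 64};
      constexpr std::array<int, 5> perm{0, 3, 1, 2, 4};

      // The permutation the pass writes into fill_constant's ShapeTensorList
      // and into the gather input/output shapes.
      std::array<int64_t, 5> trans_shape{};
      for (int i = 0; i < 5; ++i) trans_shape[i] = cachekv_shape[perm[i]];
      // Transposed layout: [2, max_seq_len, batch_size, num_head, head_dim]
      assert((trans_shape == std::array<int64_t, 5>{2, 512, 1, 24, 64}));

      // gather originally runs on axis 1 (batch); after the transpose the
      // batch dimension sits where perm[i] == 1, i.e. axis 2.
      int new_gather_axis = 0;
      for (int i = 0; i < 5; ++i) {
        if (perm[i] == 1) new_gather_axis = i;
      }
      assert(new_gather_axis == 2);
      return 0;
    }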
diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ec5dba201163fffbbc6fa4182537c73a6e25083c
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc
@@ -0,0 +1,125 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+VarDesc* Data(paddle::framework::BlockDesc* block,
+              std::string name,
+              std::vector<int64_t> shape = {},
+              bool is_persistable = false,
+              proto::VarType::Type data_type = proto::VarType::FP32) {
+  auto* var = block->Var(name);
+  var->SetType(proto::VarType::LOD_TENSOR);
+  var->SetDataType(data_type);
+  var->SetShape(shape);
+  var->SetPersistable(is_persistable);
+  return var;
+}
+
+VarDesc* fill_constant(BlockDesc* block, std::vector<VarDesc*> shapes) {
+  VarDesc* out = Data(block, shapes[0]->Name() + "_out");
+  OpDesc* op = block->AppendOp();
+  op->SetType("fill_constant");
+  std::vector<std::string> shape_names;
+  for (auto shape : shapes) {
+    shape_names.push_back(shape->Name());
+  }
+  op->SetInput("ShapeTensorList", {shape_names});
+  op->SetOutput("Out", {out->Name()});
+  return out;
+}
+
+TEST(FillConstantReshapePass, basic) {
+  paddle::framework::ProgramDesc program;
+  auto* block = program.MutableBlock(0);
+  auto* shape0 = Data(block, "shape0");
+  auto* shape1 = Data(block, "shape1");
+  auto* shape2 = Data(block, "shape2");
+  auto* shape3 = Data(block, "shape3");
+  auto* shape4 = Data(block, "shape4");
+  auto* shape5 = Data(block, "shape5");
+  auto* shape6 = Data(block, "shape6");
+  auto* shape7 = Data(block, "shape7");
+  auto* shape8 = Data(block, "shape8");
+  auto* shape9 = Data(block, "shape9");
+  auto* fill0 = fill_constant(block, {shape0, shape1, shape2, shape3, shape4});
+  fill0->SetShape({1, 2, 3, 4, 5});
+  auto* fill1 = fill_constant(block, {shape5, shape6, shape7, shape8, shape9});
+  fill1->SetShape({1, 2, 3, 4, 5});
+  OpDesc* fused_multi_transformer = block->AppendOp();
+  fused_multi_transformer->SetType("fused_multi_transformer");
+  fused_multi_transformer->SetInput("CacheKV", {fill0->Name(), fill1->Name()});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
+  auto pass = PassRegistry::Instance().Get(
+      "fused_multi_transformer_cachekv_layout_trans_pass");
+  pass->Apply(graph.get());
+  auto fills = GetOpNodes(graph, "fill_constant");
+  auto fill0_in_names = fills[0]->Op()->Input("ShapeTensorList");
+  std::vector<std::string> expect_fill0_in_names{
+      "shape0", "shape3", "shape1", "shape2", "shape4"};
+  PADDLE_ENFORCE_EQ(fill0_in_names,
+                    expect_fill0_in_names,
+                    platform::errors::PreconditionNotMet(
+                        "fill_constant name should be updated."));
+  auto fill1_in_names = fills[1]->Op()->Input("ShapeTensorList");
+  std::vector<std::string> expect_fill1_in_names{
+      "shape5", "shape8", "shape6", "shape7", "shape9"};
+  PADDLE_ENFORCE_EQ(fill1_in_names,
+                    expect_fill1_in_names,
+                    platform::errors::PreconditionNotMet(
+                        "fill_constant name should be updated."));
+}
+
+TEST(GatherReshapePass, basic) {
+  Layers layers;
+  auto* gather0_x = layers.data("gather0_x", {2, 1, 24, 512, 64});
+  auto* gather0_index = layers.data("gather0_index", {1});
+  auto* gather0_out = layers.gather(gather0_x, gather0_index, 1);
+  gather0_out->SetShape({2, 1, 24, 512, 64});
+  auto* gather1_x = layers.data("gather1_x", {2, 1, 24, 512, 64});
+  auto* gather1_index = layers.data("gather1_index", {1});
+  auto* gather1_out = layers.gather(gather1_x, gather1_index, 1);
+  gather1_out->SetShape({2, 1, 24, 512, 64});
+  auto* block = layers.Block();
+  OpDesc* fused_multi_transformer = block->AppendOp();
+  fused_multi_transformer->SetType("fused_multi_transformer");
+  fused_multi_transformer->SetInput("CacheKV",
+                                    {gather0_out->Name(), gather1_out->Name()});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
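+  // The pass should rewrite both gathers feeding CacheKV from axis 1 (batch
+  // in the original layout) to axis 2 (its position after the transpose).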
+  auto pass = PassRegistry::Instance().Get(
+      "fused_multi_transformer_cachekv_layout_trans_pass");
+  pass->Apply(graph.get());
+  auto gathers = GetOpNodes(graph, "gather");
+  for (auto* gather : gathers) {
+    PADDLE_ENFORCE_EQ(
+        gather->Op()->GetAttrIfExists<int>("axis"),
+        2,
+        platform::errors::PreconditionNotMet(
+            "gather's axis attr should be updated to 2 by pass."));
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(fused_multi_transformer_cachekv_layout_trans_pass);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 1be90972b924bd26e870ff8b74736c9fcbac0122..e7c24272b81c5568fa39e2d25df2d734e5d70a57 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -519,6 +519,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "embedding_with_eltwise_add_xpu_fuse_pass",
       "multi_encoder_xpu_fuse_pass",
       "multi_encoder_xpu_slice_fuse_pass",
+      "fused_multi_transformer_cachekv_layout_trans_pass",
       "one_beam_size_fuse_pass",
       "delete_cast_op_pass",
       "stack_fuse_pass",
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 5c0aa3b8e89fdbfbed834e00bad405146304828a..f775cedce8c114831e660d3b9a8591c00803443e 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -291,31 +291,26 @@ void FusedMultiTransformerXpuInferMeta(
     std::vector<MetaTensor*> cache_kv_out) {
   auto x_dim = x.dims();
   auto y_dim = qkvw[0]->dims();
-  PADDLE_ENFORCE_EQ(
-      x_dim.size(),
-      3,
-      phi::errors::InvalidArgument("The dimensions of x must be 3"
-                                   "(batch_size, seq_len, dim_embed),"
-                                   "but received dimensions of"
-                                   "Input is [%d]",
-                                   x_dim.size()));
+  PADDLE_ENFORCE_EQ(x_dim.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The dimensions of x must be 3(batch_size, seq_len, "
+                        "dim_embed), but received dimensions of Input is [%d]",
+                        x_dim.size()));
   PADDLE_ENFORCE_EQ(
       y_dim.size(),
       4,
-      phi::errors::InvalidArgument("The dimensions of qkv_weight must be 4"
-                                   "(3, num_head, dim_head, dim_embed),"
-                                   "but received dimensions of"
-                                   "Input is [%d]",
-                                   y_dim.size()));
+      phi::errors::InvalidArgument(
+          "The dimensions of qkv_weight must be 4(3, num_head, dim_head, "
+          "dim_embed), but received dimensions of qkv_weight is [%d]",
+          y_dim.size()));
   PADDLE_ENFORCE_EQ(
       x_dim[2],
       trans_qkvw ? y_dim[3] : y_dim[0],
       phi::errors::InvalidArgument(
-          "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
-          "true) or y_dim[0](trans_qkvw is false)"
-          "must be equal. But received: the shape "
-          "of input x = [%s], and the shape of "
-          "input qkv_weight = [%s]",
+          "The dimension of x_dim[2] and y_dim[3](trans_qkvw is true) or "
+          "y_dim[0](trans_qkvw is false) must be equal, but received: the "
+          "shape of input x = [%s], and the shape of input qkv_weight = [%s]",
           x_dim,
           y_dim));
   if (cache_kv.size() > 0) {
@@ -330,27 +325,27 @@ void FusedMultiTransformerXpuInferMeta(
         phi::errors::InvalidArgument(
             "The first dim of CacheKV must be 2, but got %d", c_dim[0]));  // 2
-    PADDLE_ENFORCE_EQ(c_dim[1],
-                      x_dim[0],
-                      phi::errors::InvalidArgument(
-                          "The second dim of CacheKV must be equal with "
-                          "batch size %d, but got %d",
-                          x_dim[0],
-                          c_dim[1]));  // batch_size
-    PADDLE_ENFORCE_EQ(c_dim[2],
-                      trans_qkvw ? y_dim[1] : y_dim[2],
-                      phi::errors::InvalidArgument(
-                          "The third dim of CacheKV must be equal with num "
-                          "head %d, but got %d",
-                          trans_qkvw ? y_dim[1] : y_dim[2],
-                          c_dim[2]));  // num_head
-    PADDLE_ENFORCE_EQ(c_dim[4],
-                      trans_qkvw ? y_dim[2] : y_dim[3],
-                      phi::errors::InvalidArgument(
-                          "The fifth dim of CacheKV must be equal with head "
-                          "size %d, but got %d",
-                          trans_qkvw ? y_dim[2] : y_dim[3],
-                          c_dim[4]));  // head_size
+    PADDLE_ENFORCE_EQ(
+        c_dim[2],
+        x_dim[0],
+        phi::errors::InvalidArgument("The third dim of CacheKV must be equal "
+                                     "with batch size %d, but got %d",
+                                     x_dim[0],
+                                     c_dim[2]));  // batch_size
+    PADDLE_ENFORCE_EQ(
+        c_dim[3],
+        trans_qkvw ? y_dim[1] : y_dim[2],
+        phi::errors::InvalidArgument("The fourth dim of CacheKV must be equal "
+                                     "with num head %d, but got %d",
+                                     trans_qkvw ? y_dim[1] : y_dim[2],
+                                     c_dim[3]));  // num_head
+    PADDLE_ENFORCE_EQ(
+        c_dim[4],
+        trans_qkvw ? y_dim[2] : y_dim[3],
+        phi::errors::InvalidArgument("The fifth dim of CacheKV must be equal "
+                                     "with head size %d, but got %d",
+                                     trans_qkvw ? y_dim[2] : y_dim[3],
+                                     c_dim[4]));  // head_size
   }
   out->set_dims(x_dim);
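As a sanity check on the updated CacheKV dimension checks, here is a rough standalone sketch (illustrative values only, not part of the patch) of the dims FusedMultiTransformerXpuInferMeta now expects with trans_qkvw = true:

    #include <array>
    #include <cassert>
    #include <cstdint>

    int main() {
      // x: [batch_size, seq_len, dim_embed]
      std::array<int64_t, 3> x_dim{1, 128, 1536};
      // qkv_weight with trans_qkvw = true: [3, num_head, dim_head, dim_embed]
      std::array<int64_t, 4> y_dim{3, 24, 64, 1536};
      const bool trans_qkvw = true;
      const int64_t max_seq_len = 512;

      // dim_embed must match the weight layout.
      assert(x_dim[2] == (trans_qkvw ? y_dim[3] : y_dim[0]));

      // Transposed CacheKV layout accepted by the new checks:
      // [2, max_seq_len, batch_size, num_head, head_dim]
      std::array<int64_t, 5> c_dim{2,
                                   max_seq_len,
                                   x_dim[0],
                                   trans_qkvw ? y_dim[1] : y_dim[2],
                                   trans_qkvw ? y_dim[2] : y_dim[3]};

      assert(c_dim[0] == 2);                                   // 2
      assert(c_dim[2] == x_dim[0]);                            // batch_size
      assert(c_dim[3] == (trans_qkvw ? y_dim[1] : y_dim[2]));  // num_head
      assert(c_dim[4] == (trans_qkvw ? y_dim[2] : y_dim[3]));  // head_size
      return 0;
    }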