diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 944dbd07bab5c489b119697fb322bd263b7cc365..46bac61bda13b4f26fc8270bf0973f71965eb41c 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -242,7 +242,7 @@ if(WITH_XPU) pass_library(one_beam_size_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(delete_isolated_node_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) - pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS + pass_library(fused_multi_transformer_xpu_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(fused_multi_transformer_cachekv_layout_trans_pass inference DIR @@ -519,9 +519,9 @@ if(WITH_XPU) SRCS xpu/delete_isolated_node_pass_test.cc DEPS delete_isolated_node_pass) cc_test( - test_fused_multi_transformer_xpu_quant_pass - SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc - DEPS fused_multi_transformer_xpu_quant_pass) + test_fused_multi_transformer_xpu_pass + SRCS xpu/fused_multi_transformer_xpu_pass_tester.cc + DEPS fused_multi_transformer_xpu_pass) cc_test( test_one_beam_size_fuse_pass SRCS xpu/one_beam_size_fuse_pass_test.cc diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 6611be59fcc800b1fa643f00732a651fa0815e66..a84c9b84e9466645a265a9fd6535b5602734a609 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -65,7 +65,7 @@ static const std::vector xpu_support_subgraph_passes = { "fused_multi_transformer_cachekv_layout_trans_pass", "one_beam_size_fuse_pass", "stack_fuse_pass", - "fused_multi_transformer_xpu_quant_pass", + "fused_multi_transformer_xpu_pass", "fc_xpu_fuse_pass", "link_xpu_op_max_pass", }; diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc similarity index 81% rename from paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc rename to paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc index b9868a81135d4d84a3353d62788b3816cfa4da3c..5676465e7131abac66c18232cd359006e9c77fea 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc @@ -39,6 +39,32 @@ namespace framework { namespace ir { namespace patterns { +struct FusedMultiTransformerAssignPattern : public PatternBase { + FusedMultiTransformerAssignPattern(PDPattern* pattern, + const std::string& name_scope); + // declare operator node's name + PATTERN_DECL_NODE(assign); + // declare variable node's name + PATTERN_DECL_NODE(assign_out); +}; + +FusedMultiTransformerAssignPattern::FusedMultiTransformerAssignPattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* assign = + pattern->NewNode(assign_repr()) + ->assert_is_op("assign") + ->assert_more([&](Node* node) { + auto pre_op_nodes = node->inputs[0]->inputs; + return pre_op_nodes.size() == 1 && + pre_op_nodes[0]->Op()->Type() == "fused_multi_transformer"; + }); + auto* assign_out = + pattern->NewNode(assign_out_repr())->assert_is_op_output("assign", "Out"); + + assign->LinksTo({assign_out}); +} + struct FusedMultiTransformerPattern : public PatternBase { FusedMultiTransformerPattern(PDPattern* pattern, const std::string& name_scope, @@ -47,7 +73,6 @@ struct 
FusedMultiTransformerPattern : public PatternBase { bool with_time_step, bool with_seq_lengths, bool with_src_mask); - // declare operator node's name PATTERN_DECL_NODE(fused_mt); // declare variable node's name @@ -234,44 +259,106 @@ FusedMultiTransformerPattern::FusedMultiTransformerPattern( } // namespace patterns /* -1. transpose and quantify the weights of fused_multi_transformer op from fp32 to +1. Remove gather and assign op to reduce graphics memory consumption +2. transpose and quantify the weights of fused_multi_transformer op from fp32 to int16 */ -class FusedMultiTransformerXPUQuantPass : public FusePassBase { +class FusedMultiTransformerXPUPass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const override; private: - int ApplyImpl(ir::Graph* graph, - bool with_pre_caches, - bool with_rotary_pos_emb, - bool with_time_step, - bool with_seq_lengths, - bool with_src_mask) const; - - const std::string name_scope_{"fused_multi_transformer_xpu_quant_pass"}; + /* + Origin subgraph: + fused_multi_transformer + | | | + assign assign ... + | | | + gather gather ... + + Fused subgraph: + fused_multi_transformer + */ + void RemoveAssignGather(ir::Graph* graph) const; + + /* + Origin subgraph: + fused_multi_transformer + + Fused subgraph: + fused_multi_transformer_xpu + */ + int FusedMultiTransformerXPUQuant(ir::Graph* graph, + bool with_pre_caches, + bool with_rotary_pos_emb, + bool with_time_step, + bool with_seq_lengths, + bool with_src_mask) const; + + const std::string name_scope_{"fused_multi_transformer_xpu_pass"}; }; -void FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph) const { +void FusedMultiTransformerXPUPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); - VLOG(3) << "in FusedMultiTransformerXPUQuantPass::ApplyImpl"; + VLOG(3) << "in FusedMultiTransformerXPUPass::ApplyImpl"; int found_subgraph_count = 0; + RemoveAssignGather(graph); for (bool with_time_step : {true, false}) { - found_subgraph_count += - ApplyImpl(graph, false, false, with_time_step, false, true); + found_subgraph_count += FusedMultiTransformerXPUQuant( + graph, false, false, with_time_step, false, true); } AddStatis(found_subgraph_count); } -int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, - bool with_pre_caches, - bool with_rotary_pos_emb, - bool with_time_step, - bool with_seq_lengths, - bool with_src_mask) const { +void FusedMultiTransformerXPUPass::RemoveAssignGather(ir::Graph* graph) const { + // detect assign + gather + GraphPatternDetector gpd; + patterns::FusedMultiTransformerAssignPattern pattern(gpd.mutable_pattern(), + name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(1) << "handle RemoveAssignGather"; + GET_IR_NODE(assign); + GET_IR_NODE(assign_out); + // Assign_out may not link to gather, so we find gather by input name. + auto next_ops = FindOpNodeByInputName(graph, assign_out->Name()); + if (next_ops.size() != 1 || next_ops[0]->Name() != "gather") return; + auto* gather = next_ops[0]; + + // "assign_out" is used in multi blocks. "assign_out" should be reserved. 
+ auto* gather_index = gather->inputs[0]; + auto* assign_in = assign->inputs[0]; + auto* fused_multi_transformer = assign_in->inputs[0]; + fused_multi_transformer->Op()->Rename(assign_in->Name(), + assign_out->Name()); + fused_multi_transformer->Op()->SetInput("gather_index", + gather->Op()->Input("Index")); + fused_multi_transformer->Op()->SetAttr("gather_axis", + gather->Op()->GetAttr("axis")); + IR_NODE_LINK_TO(gather_index, fused_multi_transformer); + IR_NODE_LINK_TO(fused_multi_transformer, assign_out); + + std::unordered_set delete_nodes{assign, assign_in, gather}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +int FusedMultiTransformerXPUPass::FusedMultiTransformerXPUQuant( + ir::Graph* graph, + bool with_pre_caches, + bool with_rotary_pos_emb, + bool with_time_step, + bool with_seq_lengths, + bool with_src_mask) const { GraphPatternDetector gpd; patterns::FusedMultiTransformerPattern pattern(gpd.mutable_pattern(), name_scope_, @@ -286,7 +373,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { - VLOG(4) << "handle FusedMultiTransformerXPUQuantPass fuse"; + VLOG(4) << "handle FusedMultiTransformerXPUQuant"; GET_IR_NODE(x); GET_IR_NODE(ln_scale); @@ -459,6 +546,13 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, if (name_caches.count("CacheKV") > 0) { fused_mt_xpu_op_desc->SetInput("cache_kv", name_caches.at("CacheKV")); } + if (name_caches.count("gather_index") > 0) { + fused_mt_xpu_op_desc->SetInput("gather_index", + name_caches.at("gather_index")); + } + if (!fused_mt_xpu_op_desc->HasAttr("gather_axis")) { + fused_mt_xpu_op_desc->SetAttr("gather_axis", 0); + } if (pre_caches) { fused_mt_xpu_op_desc->SetInput("pre_caches", name_caches.at("PreCaches")); } @@ -529,5 +623,5 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, } // namespace framework } // namespace paddle -REGISTER_PASS(fused_multi_transformer_xpu_quant_pass, - paddle::framework::ir::FusedMultiTransformerXPUQuantPass); +REGISTER_PASS(fused_multi_transformer_xpu_pass, + paddle::framework::ir::FusedMultiTransformerXPUPass); diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass_tester.cc similarity index 69% rename from paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc rename to paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass_tester.cc index cdcc20cb2f76577b1d123746bbf9c2649b1d5bcf..9251387f867eb0f1c11684950a2199593896d7c8 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass_tester.cc @@ -64,7 +64,62 @@ Scope* CreateParamScope() { return param_scope; } -TEST(FusedMultiTransformerXPUQuantPass, context_stage) { +VarDesc* Data(paddle::framework::BlockDesc* block, + std::string name, + std::vector shape = {}, + bool is_persistable = false, + proto::VarType::Type data_type = proto::VarType::FP32) { + auto* var = block->Var(name); + var->SetType(proto::VarType::LOD_TENSOR); + var->SetDataType(data_type); + var->SetShape(shape); + var->SetPersistable(is_persistable); + return var; +} + +TEST(RemoveAssignGather, basic) { + paddle::framework::ProgramDesc program; + auto* block = program.MutableBlock(0); + + auto* x = 
Data(block, "fused_multi_transformer_x", {1, 1, 1536}); + auto* cache_kv = + Data(block, "fused_multi_transformer_cache_kv", {2, 1, 24, 512, 64}); + OpDesc* fused_multi_transformer_op = block->AppendOp(); + fused_multi_transformer_op->SetType("fused_multi_transformer"); + fused_multi_transformer_op->SetInput("X", {x->Name()}); + fused_multi_transformer_op->SetInput("CacheKV", {cache_kv->Name()}); + fused_multi_transformer_op->SetOutput("CacheKVOut", {cache_kv->Name()}); + + auto* assign_out = Data(block, "assign_out", cache_kv->GetShape()); + OpDesc* assign_op = block->AppendOp(); + assign_op->SetType("assign"); + assign_op->SetInput("X", {cache_kv->Name()}); + assign_op->SetOutput("Out", {assign_out->Name()}); + + OpDesc* gather_op = block->AppendOp(); + auto gather_index = Data(block, "gather_index", {10}); + gather_op->SetType("gather"); + gather_op->SetInput("X", {assign_out->Name()}); + gather_op->SetInput("Index", {gather_index->Name()}); + gather_op->SetAttr("axis", {1}); + gather_op->SetOutput("Out", {cache_kv->Name()}); + + std::unique_ptr graph(new ir::Graph(program)); + auto pass = PassRegistry::Instance().Get("fused_multi_transformer_xpu_pass"); + pass->Apply(graph.get()); + auto assign_num = GetNumOpNodes(graph, "assign"); + auto gather_num = GetNumOpNodes(graph, "gather"); + PADDLE_ENFORCE_EQ(assign_num, + 0, + platform::errors::PreconditionNotMet( + "assign op should be removed from the graph.")); + PADDLE_ENFORCE_EQ(gather_num, + 0, + platform::errors::PreconditionNotMet( + "gather op should be removed from the graph.")); +} + +TEST(FusedMultiTransformerXPUPass, context_stage) { DEF_INPUT_DATA auto* cache_kv = layers.fill_constant_batch_size_like( @@ -95,10 +150,9 @@ TEST(FusedMultiTransformerXPUQuantPass, context_stage) { std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); - auto pass = - PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass"); + auto pass = PassRegistry::Instance().Get("fused_multi_transformer_xpu_pass"); if (pass.get() == nullptr) { - LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed"; + LOG(INFO) << "get fused_multi_transformer_xpu_pass failed"; } graph.reset(pass->Apply(graph.release())); @@ -114,7 +168,7 @@ TEST(FusedMultiTransformerXPUQuantPass, context_stage) { num_nodes_after)); } -TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) { +TEST(FusedMultiTransformerXPUPass, decoder_stage) { DEF_INPUT_DATA auto* cache_kv = layers.fill_constant_batch_size_like( @@ -146,10 +200,9 @@ TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) { std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); - auto pass = - PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass"); + auto pass = PassRegistry::Instance().Get("fused_multi_transformer_xpu_pass"); if (pass.get() == nullptr) { - LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed"; + LOG(INFO) << "get fused_multi_transformer_xpu_pass failed"; } graph.reset(pass->Apply(graph.release())); @@ -169,4 +222,4 @@ TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) { } // namespace framework } // namespace paddle -USE_PASS(fused_multi_transformer_xpu_quant_pass); +USE_PASS(fused_multi_transformer_xpu_pass); diff --git a/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc index 9ab6cd031b5c72b40b4af989eb736e0407e08314..f50ee81eaaefdec7a0b1d2733d9c905767472b7c 100644 --- 
a/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc @@ -259,23 +259,6 @@ bool OnlyOneBeamSearchAndOneBeamSize(ir::Graph* graph) { beam_search_nodes[0]->Op()->GetAttrIfExists("beam_size") == 1; } -std::vector FindOpNodeByInputName(Graph* graph, - const std::string& var_name) { - std::vector ret; - for (auto* node : graph->Nodes()) { - if (!node->IsOp()) continue; - auto inputs = node->Op()->Inputs(); - for (auto input : inputs) { - auto in_names = input.second; - if (std::count(in_names.begin(), in_names.end(), var_name) > 0) { - ret.push_back(node); - break; - } - } - } - return ret; -} - void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const { // detect assign + gather GraphPatternDetector gpd; diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc index aaa117d363a5fd6d4ac4d42482df929cb4e90d90..eeb0e23e19ecde75bc6db4d95c45eade95cacd18 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.cc +++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc @@ -71,6 +71,23 @@ Node* FindNodeWithName(Graph* graph, std::string name) { return nullptr; } +std::vector FindOpNodeByInputName(Graph* graph, + const std::string& var_name) { + std::vector ret; + for (auto* node : graph->Nodes()) { + if (!node->IsOp()) continue; + auto inputs = node->Op()->Inputs(); + for (auto input : inputs) { + auto in_names = input.second; + if (std::count(in_names.begin(), in_names.end(), var_name) > 0) { + ret.push_back(node); + break; + } + } + } + return ret; +} + template std::string IntTypeToString() { LOG(FATAL) << "Not support type."; diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.h b/paddle/fluid/framework/ir/xpu/pass_utils.h index 68cfb2953e1d5b2fe1c4f39aa30552a50209db2d..d1e7b218a0b4683c7b9775bb8cf9e0b65b731c84 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.h +++ b/paddle/fluid/framework/ir/xpu/pass_utils.h @@ -51,6 +51,9 @@ int ConvertActivationType(std::string act_type); Node* FindNodeWithName(Graph* graph, std::string name); +std::vector FindOpNodeByInputName(Graph* graph, + const std::string& var_name); + template size_t HashTensor(const phi::DenseTensor& in); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 134d56180b6177350b191e947bcfab4ade59f75d..580578f2a04ad95051f2f074a7d5fc25fc5ff15f 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -523,7 +523,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "one_beam_size_fuse_pass", "delete_cast_op_pass", "stack_fuse_pass", - "fused_multi_transformer_xpu_quant_pass", + "fused_multi_transformer_xpu_pass", "fc_xpu_fuse_pass", "conv2d_xpu_fuse_pass", "link_xpu_op_max_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index f4570f711412d6fc0e22ab6e356761ca37c0af46..87400ecd61c9f4a0a2c3ffb8526c571319e8c908 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -58,14 +58,14 @@ support_dygraph_mode : true - op : fused_multi_transformer_xpu - args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] 
cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id) + args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id, int gather_axis) output : Tensor(out), Tensor[](cache_kv_out){out_linear_w.size()} infer_meta : func : FusedMultiTransformerXpuInferMeta kernel : func : fused_multi_transformer_xpu data_type : x - optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask + optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index - op : generate_sequence_xpu args : (Tensor x, DataType dtype) diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index fb7ede4c9e7423201818135f6a8a64e43c3d953c..262c5fb04e19f08a16a1deb7e4e529cf698b2180 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -278,6 +278,7 @@ void FusedMultiTransformerXpuInferMeta( const std::vector& time_step, const std::vector& seq_lengths, const std::vector& src_mask, + const std::vector& gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, @@ -287,6 +288,7 @@ void FusedMultiTransformerXpuInferMeta( const std::string& act_method, bool trans_qkvw, int ring_id, + int gather_axis, MetaTensor* out, std::vector cache_kv_out) { auto x_dim = x.dims(); @@ -325,13 +327,6 @@ void FusedMultiTransformerXpuInferMeta( phi::errors::InvalidArgument( "The first dim of CacheKV must be 2, but got %d", c_dim[0])); // 2 - PADDLE_ENFORCE_EQ( - c_dim[2], - x_dim[0], - phi::errors::InvalidArgument("The third dim of CacheKV must be equal " - "with batch size %d, but got %d", - x_dim[0], - c_dim[2])); // batch_size PADDLE_ENFORCE_EQ( c_dim[3], trans_qkvw ? 
y_dim[1] : y_dim[2], diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 3105ea8a6d578132740784b4c9c14967fb2e6526..38f4bc8c6c5be9c85896f80d57a562f7bf42257f 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -108,6 +108,7 @@ void FusedMultiTransformerXpuInferMeta( const std::vector& time_step, const std::vector& seq_lengths, const std::vector& src_mask, + const std::vector& gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, @@ -117,6 +118,7 @@ void FusedMultiTransformerXpuInferMeta( const std::string& act_method, bool trans_qkvw, int ring_id, + int gather_axis, MetaTensor* out, std::vector cache_kv_out); } // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc index 0d27a398e449be95b618a32a3163fed7458d3ab0..29a3c7d61adffde21a4520c9a49044e3b16ff0b6 100644 --- a/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc @@ -17,6 +17,8 @@ #include "glog/logging.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/infermeta/binary.h" #include "paddle/phi/kernels/memcpy_kernel.h" #ifdef PADDLE_WITH_XPU_XFT #include "models/fused_multi_transformer_op.h" @@ -52,6 +54,7 @@ void FusedMultiTransformerXpuKernel( const paddle::optional& time_step, const paddle::optional& seq_lengths, const paddle::optional& src_mask, + const paddle::optional& gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, @@ -61,6 +64,7 @@ void FusedMultiTransformerXpuKernel( const std::string& act_method, bool trans_qkvw, int ring_id, + int gather_axis, DenseTensor* out, std::vector cache_kv_out) { #ifdef PADDLE_WITH_XPU_XFT @@ -160,6 +164,21 @@ void FusedMultiTransformerXpuKernel( std::vector> xft_cache_kv; std::vector> xft_cache_kv_out; + // Create a temporary Tensor to store the gather output of cache_kv + auto gather_index_t = gather_index.get_ptr(); + auto cache_kv_dims = cache_kv.get_ptr()->at(0)->dims(); + auto cache_kv_gather_dims = cache_kv_dims; + phi::DenseTensor cache_kv_gather_tensor; + if (gather_index_t) { + MetaTensor cache_kv_gather_meta(&cache_kv_gather_tensor); + phi::GatherInferMeta(*cache_kv.get_ptr()->at(0), + *gather_index_t, + Scalar(gather_axis), + &cache_kv_gather_meta); + cache_kv_gather_dims = cache_kv_gather_meta.dims(); + ctx.template Alloc(&cache_kv_gather_tensor); + } + int layers = qkvw.size(); for (int i = 0; i < layers; ++i) { // step1. layer_norm @@ -211,27 +230,55 @@ void FusedMultiTransformerXpuKernel( xft_ffn2_bias.emplace_back(const_cast(ffn2_bias[i]->data()), std::array{ffn2_bias[i]->dims()[0]}); // cache kv in - if (time_step_value > 0) { - auto cachekv_dims = cache_kv.get_ptr()->at(i)->dims(); - xft_cache_kv.emplace_back(reinterpret_cast(const_cast( - cache_kv.get_ptr()->at(i)->data())), - std::array{cachekv_dims[0], - cachekv_dims[1], - cachekv_dims[2], - cachekv_dims[3], - cachekv_dims[4]}); + auto cache_kv_data = reinterpret_cast( + const_cast(cache_kv.get_ptr()->at(i)->data())); + if (gather_index_t) { + const auto& index_type = gather_index_t->dtype(); + if (index_type == DataType::INT32) { + r = xpu::gather( + ctx.x_context(), + cache_kv_data, + gather_index_t->data(), + reinterpret_cast(cache_kv_gather_tensor.data()), + phi::vectorize(cache_kv_dims), + gather_index_t->dims().size() == 0 ? 
1 : gather_index_t->dims()[0], + gather_axis); + } else { + r = xpu::gather( + ctx.x_context(), + cache_kv_data, + gather_index_t->data(), + reinterpret_cast(cache_kv_gather_tensor.data()), + phi::vectorize(cache_kv_dims), + gather_index_t->dims().size() == 0 ? 1 : gather_index_t->dims()[0], + gather_axis); + } + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::gather"); + cache_kv_out[i]->ResizeAndAllocate(cache_kv_gather_dims); + r = xpu::copy( + ctx.x_context(), + reinterpret_cast(cache_kv_gather_tensor.data()), + reinterpret_cast(ctx.template Alloc(cache_kv_out[i])), + cache_kv_out[i]->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::copy"); } - // cache kv out - auto cachekv_out_dims = cache_kv_out[i]->dims(); + cache_kv_data = reinterpret_cast( + const_cast(cache_kv.get_ptr()->at(i)->data())); + xft_cache_kv.emplace_back(cache_kv_data, + std::array{cache_kv_gather_dims[0], + cache_kv_gather_dims[1], + cache_kv_gather_dims[2], + cache_kv_gather_dims[3], + cache_kv_gather_dims[4]}); + // cache kv out direct use cache_kv_data xft_cache_kv_out.emplace_back( - reinterpret_cast(ctx.template Alloc(cache_kv_out[i])), - std::array{cachekv_out_dims[0], - cachekv_out_dims[1], - cachekv_out_dims[2], - cachekv_out_dims[3], - cachekv_out_dims[4]}); + cache_kv_data, + std::array{cache_kv_gather_dims[0], + cache_kv_gather_dims[1], + cache_kv_gather_dims[2], + cache_kv_gather_dims[3], + cache_kv_gather_dims[4]}); } - xft::NlpParam param; param.num_layer = layers; param.n_head = num_head;
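
The kernel change above folds the beam-selection gather into the fused op: when gather_index is present, cache_kv is gathered along gather_axis into a temporary tensor before the buffers are handed to XFT, which is what allows the graph pass to drop the standalone assign + gather pair. Below is a minimal standalone sketch of that gather semantics — not Paddle code; the shapes, the choice of the beam axis, and the helper name GatherAlongAxis are illustrative assumptions — with the cache viewed as a row-major [2, beam, head, seq, dim] buffer and beams selected along axis 1.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Gather slices along `axis` of a row-major tensor with shape `dims`,
// keeping the slices named in `index` (this mirrors what the removed
// assign + gather pair computed on cache_kv).
std::vector<float> GatherAlongAxis(const std::vector<float>& src,
                                   const std::vector<int64_t>& dims,
                                   const std::vector<int32_t>& index,
                                   int axis) {
  int64_t outer = 1, inner = 1;
  for (int i = 0; i < axis; ++i) outer *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) inner *= dims[i];
  const int64_t axis_dim = dims[axis];
  const int64_t kept = static_cast<int64_t>(index.size());
  std::vector<float> dst(outer * kept * inner);
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t j = 0; j < kept; ++j) {
      const float* src_slice = src.data() + (o * axis_dim + index[j]) * inner;
      float* dst_slice = dst.data() + (o * kept + j) * inner;
      std::copy(src_slice, src_slice + inner, dst_slice);
    }
  }
  return dst;
}

int main() {
  // cache_kv viewed as [2, beam = 4, head = 1, seq = 2, dim = 1],
  // filled so every element equals its beam id (illustrative values only).
  std::vector<int64_t> dims{2, 4, 1, 2, 1};
  std::vector<float> cache(16);
  for (int64_t i = 0; i < 16; ++i) cache[i] = static_cast<float>((i / 2) % 4);
  // gather_index = {2, 0}, gather_axis = 1: keep the cache rows of beams 2 and 0.
  std::vector<int32_t> index{2, 0};
  auto out = GatherAlongAxis(cache, dims, index, /*axis=*/1);
  for (float v : out) std::cout << v << ' ';  // prints: 2 2 0 0 2 2 0 0
  std::cout << '\n';
  return 0;
}

Folding the gather in this way means the selected cache slices are produced once inside the fused kernel instead of first copying the full cache with assign and then gathering it, which matches the memory-saving motivation stated in the pass comment.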