From 52e1742fd4431c4b10e2ac87d38f35f93cbe0d08 Mon Sep 17 00:00:00 2001 From: mayang002 <77949147+mayang002@users.noreply.github.com> Date: Mon, 20 Mar 2023 14:23:06 +0800 Subject: [PATCH] [xpu] fused_multi_transformer_xpu pass&kernel support (#51571) --- cmake/external/xpu.cmake | 2 + paddle/fluid/framework/ir/CMakeLists.txt | 6 + ...use_multi_transformer_layer_pass_tester.cc | 8 +- paddle/fluid/framework/ir/node.h | 9 + paddle/fluid/framework/ir/pass.cc | 1 + .../fluid/framework/ir/pass_tester_helper.h | 58 +- .../fused_multi_transformer_xpu_quant_pass.cc | 546 ++++++++++++++++++ ...multi_transformer_xpu_quant_pass_tester.cc | 170 ++++++ .../inference/api/paddle_pass_builder.cc | 1 + paddle/phi/api/yaml/static_ops.yaml | 10 + paddle/phi/backends/xpu/xpu2_op_list.cc | 2 + paddle/phi/infermeta/fusion.cc | 104 ++++ paddle/phi/infermeta/fusion.h | 35 ++ .../xpu/fused_multi_transformer_xpu_kernel.cc | 275 +++++++++ 14 files changed, 1196 insertions(+), 31 deletions(-) create mode 100644 paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc create mode 100644 paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc create mode 100644 paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index a64851c7abe..b930e5557e2 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -142,6 +142,8 @@ if(WITH_XPU_XFT) message(STATUS "Compile with XPU XFT!") add_definitions(-DPADDLE_WITH_XPU_XFT) + set(XPU_XFT_INC_DIR "${XPU_INC_DIR}/xft") + include_directories(${XPU_XFT_INC_DIR}) set(XPU_XFT_LIB "${XPU_LIB_DIR}/${XPU_XFT_LIB_NAME}") endif() diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 81b12bd1ee6..970f6f32c58 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -235,6 +235,8 @@ if(WITH_XPU) pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(delete_isolated_node_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) endif() cc_library( @@ -493,4 +495,8 @@ if(WITH_XPU) test_delete_isolated_node_pass SRCS xpu/delete_isolated_node_pass_test.cc DEPS delete_isolated_node_pass) + cc_test( + test_fused_multi_transformer_xpu_quant_pass + SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc + DEPS fused_multi_transformer_xpu_quant_pass) endif() diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc index c96935a9ac6..0821c5bdb71 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc @@ -75,7 +75,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) { 1, {2, -1, 16, 1024, 64}, 0); - auto* out = layers.fused_multi_transformer(x, + auto outs = layers.fused_multi_transformer(x, cache_kv, src_mask, qkv_w, @@ -93,7 +93,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) { 0.1, 1e-12); - x = out; + x = outs[0]; } std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -126,7 +126,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) { for (int i = 0; i < num_layers; ++i) { auto* shape_out = layers.shape(src_mask); auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4}); - auto* out = 
layers.fused_multi_transformer(x, + auto outs = layers.fused_multi_transformer(x, cache_kv, src_mask, qkv_w, @@ -145,7 +145,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) { 1e-12, time_stamp); - x = out; + x = outs[0]; } std::unique_ptr graph(new ir::Graph(layers.main_program())); auto param_scope = CreateParamScope(); diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 1381ea64d7b..fbd26b42c35 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -151,6 +151,15 @@ class Node { var_desc_->SetName(new_name); } + void RenameOp(const std::string& new_name) { + PADDLE_ENFORCE_EQ( + type_ == Type::kOperation && op_desc_, + true, + platform::errors::InvalidArgument("Node must be type of variable.")); + name_ = new_name; + op_desc_->SetType(new_name); + } + int DescOrder() const { return desc_order_; } int GetVarNodeBlockId() const { diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index c064040cf42..623e5b7a357 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -49,6 +49,7 @@ static const std::vector support_subgraph_passes = { "fuse_multi_transformer_layer_pass", "delete_quant_dequant_linear_op_pass", "delete_weight_dequant_linear_op_pass", + "fused_multi_transformer_xpu_quant_pass", "fc_xpu_fuse_pass", "delete_op_device_pass"}; diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index dc423d9d17d..07b4034823b 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -571,33 +571,35 @@ struct Layers { return out; } - VarDesc* fused_multi_transformer(VarDesc* x, - VarDesc* cache_kv, - VarDesc* src_mask, - VarDesc* qkv_w, - VarDesc* qkv_bias, - VarDesc* out_linear_w, - VarDesc* out_linear_bias, - VarDesc* ffn1_w, - VarDesc* ffn1_bias, - VarDesc* ffn2_w, - VarDesc* ffn2_bias, - VarDesc* ln_scale, - VarDesc* ln_bias, - VarDesc* ffn_ln_scale, - VarDesc* ffn_ln_bias, - float epsilon, - float dropout_rate, - VarDesc* time_stamp = nullptr, - VarDesc* qkv_out_scale = nullptr, - VarDesc* out_linear_out_scale = nullptr, - VarDesc* ffn1_out_scale = nullptr, - VarDesc* ffn2_out_scale = nullptr, - std::vector qkv_in_scale = {}, - std::vector out_linear_in_scale = {}, - std::vector ffn1_in_scale = {}, - std::vector ffn2_in_scale = {}) { + std::vector fused_multi_transformer( + VarDesc* x, + VarDesc* cache_kv, + VarDesc* src_mask, + VarDesc* qkv_w, + VarDesc* qkv_bias, + VarDesc* out_linear_w, + VarDesc* out_linear_bias, + VarDesc* ffn1_w, + VarDesc* ffn1_bias, + VarDesc* ffn2_w, + VarDesc* ffn2_bias, + VarDesc* ln_scale, + VarDesc* ln_bias, + VarDesc* ffn_ln_scale, + VarDesc* ffn_ln_bias, + float epsilon, + float dropout_rate, + VarDesc* time_stamp = nullptr, + VarDesc* qkv_out_scale = nullptr, + VarDesc* out_linear_out_scale = nullptr, + VarDesc* ffn1_out_scale = nullptr, + VarDesc* ffn2_out_scale = nullptr, + std::vector qkv_in_scale = {}, + std::vector out_linear_in_scale = {}, + std::vector ffn1_in_scale = {}, + std::vector ffn2_in_scale = {}) { VarDesc* out = lod_tensor(unique_name()); + VarDesc* cache_kv_out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); std::string op_type = qkv_out_scale ? 
"fused_multi_transformer_int8" : "fused_multi_transformer"; @@ -623,6 +625,7 @@ struct Layers { op->SetAttr("dropout_rate", dropout_rate); op->SetAttr("epsilon", epsilon); op->SetOutput("Out", {out->Name()}); + op->SetOutput("CacheKVOut", {cache_kv_out->Name()}); if (time_stamp) { op->SetInput("TimeStep", {time_stamp->Name()}); @@ -638,7 +641,8 @@ struct Layers { op->SetAttr("ffn1_in_scale", ffn1_in_scale); op->SetAttr("ffn2_in_scale", ffn2_in_scale); } - return out; + std::vector outs = {out, cache_kv_out}; + return outs; } VarDesc* dequantize_linear(VarDesc* x, diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc new file mode 100644 index 00000000000..43e8821cd88 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc @@ -0,0 +1,546 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct FusedMultiTransformerPattern : public PatternBase { + FusedMultiTransformerPattern(PDPattern* pattern, + const std::string& name_scope, + bool with_cache_kv, + bool with_pre_caches, + bool with_rotary_pos_emb, + bool with_time_step, + bool with_seq_lengths, + bool with_src_mask); + + // declare operator node's name + PATTERN_DECL_NODE(fused_mt); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(ln_scale); + PATTERN_DECL_NODE(ln_bias); + PATTERN_DECL_NODE(qkv_w); + PATTERN_DECL_NODE(qkv_bias); + PATTERN_DECL_NODE(cache_kv); + PATTERN_DECL_NODE(pre_caches); + PATTERN_DECL_NODE(rotary_pos_emb); + PATTERN_DECL_NODE(time_step); + PATTERN_DECL_NODE(seq_lengths); + PATTERN_DECL_NODE(src_mask); + PATTERN_DECL_NODE(out_linear_w); + PATTERN_DECL_NODE(out_linear_bias); + PATTERN_DECL_NODE(ffn_ln_scale); + PATTERN_DECL_NODE(ffn_ln_bias); + PATTERN_DECL_NODE(ffn1_w); + PATTERN_DECL_NODE(ffn1_bias); + PATTERN_DECL_NODE(ffn2_w); + PATTERN_DECL_NODE(ffn2_bias); + PATTERN_DECL_NODE(cache_kv_out); + PATTERN_DECL_NODE(out); + + private: + bool with_cache_kv_{false}; + bool with_pre_caches_{false}; + bool with_rotary_pos_emb_{false}; + bool with_time_step_{false}; + bool with_seq_lengths_{false}; + bool with_src_mask_{false}; +}; + +FusedMultiTransformerPattern::FusedMultiTransformerPattern( + PDPattern* pattern, + const std::string& name_scope, + bool 
with_cache_kv, + bool with_pre_caches, + bool with_rotary_pos_emb, + bool with_time_step, + bool with_seq_lengths, + bool with_src_mask) + : PatternBase(pattern, name_scope, name_scope), + with_cache_kv_(with_cache_kv), + with_pre_caches_(with_pre_caches), + with_rotary_pos_emb_(with_rotary_pos_emb), + with_time_step_(with_time_step), + with_seq_lengths_(with_seq_lengths), + with_src_mask_(with_src_mask) { + std::string op_type = "fused_multi_transformer"; + auto* fused_mt = pattern->NewNode(fused_mt_repr())->assert_is_op(op_type); + // inputs and outputs + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_input(op_type, "X") + ->assert_var_not_persistable(); + auto* cache_kv_out = pattern->NewNode(cache_kv_out_repr()) + ->assert_is_op_output(op_type, "CacheKVOut") + ->assert_var_not_persistable(); + auto* out = pattern->NewNode(out_repr()) + ->assert_is_op_output(op_type, "Out") + ->assert_var_not_persistable(); + // weights and biases + auto* ln_scale = pattern->NewNode(ln_scale_repr()) + ->assert_is_op_input(op_type, "LnScale") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + auto* ln_bias = pattern->NewNode(ln_bias_repr()) + ->assert_is_op_input(op_type, "LnBias") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + auto* qkv_w = pattern->NewNode(qkv_w_repr()) + ->assert_is_op_input(op_type, "QKVW") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 4; + }); + auto* qkv_bias = pattern->NewNode(qkv_bias_repr()) + ->assert_is_op_input(op_type, "QKVBias") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 3; + }); + auto* out_linear_w = pattern->NewNode(out_linear_w_repr()) + ->assert_is_op_input(op_type, "OutLinearW") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 2; + }); + auto* out_linear_bias = pattern->NewNode(out_linear_bias_repr()) + ->assert_is_op_input(op_type, "OutLinearBias") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + auto* ffn_ln_scale = pattern->NewNode(ffn_ln_scale_repr()) + ->assert_is_op_input(op_type, "FFNLnScale") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + auto* ffn_ln_bias = pattern->NewNode(ffn_ln_bias_repr()) + ->assert_is_op_input(op_type, "FFNLnBias") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + auto* ffn1_w = pattern->NewNode(ffn1_w_repr()) + ->assert_is_op_input(op_type, "FFN1Weight") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 2; + }); + auto* ffn1_bias = pattern->NewNode(ffn1_bias_repr()) + ->assert_is_op_input(op_type, "FFN1Bias") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + auto* ffn2_w = pattern->NewNode(ffn2_w_repr()) + ->assert_is_op_input(op_type, "FFN2Weight") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 2; + }); + auto* ffn2_bias = pattern->NewNode(ffn2_bias_repr()) + ->assert_is_op_input(op_type, "FFN2Bias") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + 
+ std::vector input_vars{x, + ln_scale, + ln_bias, + qkv_w, + qkv_bias, + out_linear_w, + out_linear_bias, + ffn_ln_scale, + ffn_ln_bias, + ffn1_w, + ffn1_bias, + ffn2_w, + ffn2_bias}; + std::vector output_vars{cache_kv_out, out}; + + // optional node + PDNode* cache_kv = nullptr; + PDNode* pre_caches = nullptr; + PDNode* rotary_pos_emb = nullptr; + PDNode* time_step = nullptr; + PDNode* seq_lengths = nullptr; + PDNode* src_mask = nullptr; + if (with_cache_kv_) { + cache_kv = pattern->NewNode(cache_kv_repr()) + ->assert_is_op_input(op_type, "CacheKV") + ->assert_var_not_persistable(); + input_vars.push_back(cache_kv); + } + if (with_pre_caches_) { + pre_caches = pattern->NewNode(pre_caches_repr()) + ->assert_is_op_input(op_type, "PreCaches") + ->assert_var_not_persistable(); + input_vars.push_back(pre_caches); + } + if (with_rotary_pos_emb_) { + rotary_pos_emb = pattern->NewNode(rotary_pos_emb_repr()) + ->assert_is_op_input(op_type, "RotaryPosEmb") + ->assert_var_not_persistable(); + input_vars.push_back(rotary_pos_emb); + } + if (with_time_step_) { + time_step = pattern->NewNode(time_step_repr()) + ->assert_is_op_input(op_type, "TimeStep") + ->assert_var_not_persistable(); + input_vars.push_back(time_step); + } + if (with_seq_lengths_) { + seq_lengths = pattern->NewNode(seq_lengths_repr()) + ->assert_is_op_input(op_type, "SeqLengths") + ->assert_var_not_persistable(); + input_vars.push_back(seq_lengths); + } + if (with_src_mask_) { + src_mask = pattern->NewNode(src_mask_repr()) + ->assert_is_op_input(op_type, "SrcMask") + ->assert_var_not_persistable(); + input_vars.push_back(src_mask); + } + + fused_mt->LinksFrom(input_vars).LinksTo(output_vars); +} + +} // namespace patterns + +/* +1. transpose and quantify the weights of fused_multi_transformer op from fp32 to +int16 +*/ +class FusedMultiTransformerXPUQuantPass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + int ApplyImpl(ir::Graph* graph, + bool with_cache_kv, + bool with_pre_caches, + bool with_rotary_pos_emb, + bool with_time_step, + bool with_seq_lengths, + bool with_src_mask) const; + + const std::string name_scope_{"fused_multi_transformer_xpu_quant_pass"}; +}; + +void FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + VLOG(3) << "in FusedMultiTransformerXPUQuantPass::ApplyImpl"; + + int found_subgraph_count = 0; + for (bool with_time_step : {true, false}) { + found_subgraph_count += + ApplyImpl(graph, true, false, false, with_time_step, false, true); + } + AddStatis(found_subgraph_count); +} + +int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, + bool with_cache_kv, + bool with_pre_caches, + bool with_rotary_pos_emb, + bool with_time_step, + bool with_seq_lengths, + bool with_src_mask) const { + GraphPatternDetector gpd; + patterns::FusedMultiTransformerPattern pattern(gpd.mutable_pattern(), + name_scope_, + with_cache_kv, + with_pre_caches, + with_rotary_pos_emb, + with_time_step, + with_seq_lengths, + with_src_mask); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle FusedMultiTransformerXPUQuantPass fuse"; + + GET_IR_NODE(x); + GET_IR_NODE(ln_scale); + GET_IR_NODE(ln_bias); + GET_IR_NODE(qkv_w); + GET_IR_NODE(qkv_bias); + GET_IR_NODE(cache_kv); + GET_IR_NODE(pre_caches); + GET_IR_NODE(rotary_pos_emb); + 
GET_IR_NODE(time_step); + GET_IR_NODE(seq_lengths); + GET_IR_NODE(src_mask); + GET_IR_NODE(out_linear_w); + GET_IR_NODE(out_linear_bias); + GET_IR_NODE(ffn_ln_scale); + GET_IR_NODE(ffn_ln_bias); + GET_IR_NODE(ffn1_w); + GET_IR_NODE(ffn1_bias); + GET_IR_NODE(ffn2_w); + GET_IR_NODE(ffn2_bias); + GET_IR_NODE(cache_kv_out); + GET_IR_NODE(out); + GET_IR_NODE(fused_mt); + auto* block = fused_mt->Op()->Block(); + auto* scope = param_scope(); + + // quant weight nodes + // w_nodes_vec: [QKVW, OutLinearW, FFN1Weight, FFN2Weight] + std::vector> w_nodes_vec(4); + std::vector> w_int16_nodes_vec(4); + std::vector> w_max_nodes_vec(4); + std::vector> w_int16_names_vec(4); + std::vector> w_max_names_vec(4); + auto quant_func = [&](const std::string& input_name, + std::vector* w_nodes, + std::vector* w_int16_nodes, + std::vector* w_max_nodes, + std::vector* w_int16_names, + std::vector* w_max_names, + bool need_transpose) { + typedef int16_t TW; + + auto w_names = fused_mt->Op()->Input(input_name); + for (auto w_name : w_names) { + Node* w_node = FindNodeWithName(graph, w_name); + Node* w_int16 = nullptr; + Node* w_max = nullptr; + PADDLE_ENFORCE_NE( + w_node, + nullptr, + platform::errors::Fatal("w node should not be nullptr")); + PrepareWeight( + graph, scope, block, w_node, &w_int16, &w_max, need_transpose); + w_nodes->push_back(w_node); + w_int16_nodes->push_back(w_int16); + w_max_nodes->push_back(w_max); + } + for (size_t i = 0; i < w_names.size(); ++i) { + w_int16_names->push_back(w_int16_nodes->at(i)->Name()); + w_max_names->push_back(w_max_nodes->at(i)->Name()); + } + PADDLE_ENFORCE_EQ( + w_names.size(), + w_nodes->size(), + platform::errors::Fatal( + "The size of w_names(%d) should be equal to w_nodes(%d)", + static_cast(w_names.size()), + static_cast(w_nodes->size()))); + PADDLE_ENFORCE_EQ( + w_names.size(), + w_int16_nodes->size(), + platform::errors::Fatal( + "The size of w_names(%d) should be equal to w_int16_nodes(%d)", + static_cast(w_names.size()), + static_cast(w_int16_nodes->size()))); + PADDLE_ENFORCE_EQ( + w_names.size(), + w_max_nodes->size(), + platform::errors::Fatal( + "The size of w_names(%d) should be equal to w_max_nodes(%d)", + static_cast(w_names.size()), + static_cast(w_max_nodes->size()))); + PADDLE_ENFORCE_EQ( + w_names.size(), + w_int16_names->size(), + platform::errors::Fatal( + "The size of w_names(%d) should be equal to w_int16_names(%d)", + static_cast(w_names.size()), + static_cast(w_int16_names->size()))); + PADDLE_ENFORCE_EQ( + w_names.size(), + w_max_names->size(), + platform::errors::Fatal( + "The size of w_names(%d) should be equal to w_max_names(%d)", + static_cast(w_names.size()), + static_cast(w_max_names->size()))); + }; + quant_func("QKVW", + &(w_nodes_vec[0]), + &(w_int16_nodes_vec[0]), + &(w_max_nodes_vec[0]), + &(w_int16_names_vec[0]), + &(w_max_names_vec[0]), + false); + quant_func("OutLinearW", + &(w_nodes_vec[1]), + &(w_int16_nodes_vec[1]), + &(w_max_nodes_vec[1]), + &(w_int16_names_vec[1]), + &(w_max_names_vec[1]), + true); + quant_func("FFN1Weight", + &(w_nodes_vec[2]), + &(w_int16_nodes_vec[2]), + &(w_max_nodes_vec[2]), + &(w_int16_names_vec[2]), + &(w_max_names_vec[2]), + true); + quant_func("FFN2Weight", + &(w_nodes_vec[3]), + &(w_int16_nodes_vec[3]), + &(w_max_nodes_vec[3]), + &(w_int16_names_vec[3]), + &(w_max_names_vec[3]), + true); + + // cast some nodes to fp32 nodes + std::vector fp32_nodes; + auto cast_tofp32_func = [&](const std::string& input_name) { + auto names = fused_mt->Op()->Input(input_name); + for (auto name : names) { + auto* 
curr_tensor = scope->Var(name)->GetMutable(); + PADDLE_ENFORCE_NE( + curr_tensor, + nullptr, + platform::errors::Fatal("tensor node should not be nullptr")); + CastToFp32(curr_tensor); + + Node* curr_node = FindNodeWithName(graph, name); + fp32_nodes.push_back(curr_node); + } + }; + cast_tofp32_func("LnScale"); + cast_tofp32_func("LnBias"); + cast_tofp32_func("QKVBias"); + cast_tofp32_func("OutLinearBias"); + cast_tofp32_func("FFNLnScale"); + cast_tofp32_func("FFNLnBias"); + cast_tofp32_func("FFN1Bias"); + cast_tofp32_func("FFN2Bias"); + + // Generate fused_multi_transformer_xpu op inplace + fused_mt->RenameOp("fused_multi_transformer_xpu"); + framework::OpDesc* fused_mt_xpu_op_desc = fused_mt->Op(); + fused_mt_xpu_op_desc->SetType("fused_multi_transformer_xpu"); + std::unordered_map> name_caches; + for (auto key : fused_mt_xpu_op_desc->InputNames()) { + name_caches.insert({key, fused_mt_xpu_op_desc->Input(key)}); + } + for (auto key : fused_mt_xpu_op_desc->OutputNames()) { + name_caches.insert({key, fused_mt_xpu_op_desc->Output(key)}); + } + fused_mt_xpu_op_desc->MutableInputs()->clear(); + fused_mt_xpu_op_desc->MutableOutputs()->clear(); + fused_mt_xpu_op_desc->SetInput("x", name_caches.at("X")); + fused_mt_xpu_op_desc->SetInput("ln_scale", name_caches.at("LnScale")); + fused_mt_xpu_op_desc->SetInput("ln_bias", name_caches.at("LnBias")); + fused_mt_xpu_op_desc->SetInput("qkv_bias", name_caches.at("QKVBias")); + if (cache_kv) { + fused_mt_xpu_op_desc->SetInput("cache_kv", name_caches.at("CacheKV")); + } + if (pre_caches) { + fused_mt_xpu_op_desc->SetInput("pre_caches", name_caches.at("PreCaches")); + } + if (rotary_pos_emb) { + fused_mt_xpu_op_desc->SetInput("rotary_pos_emb", + name_caches.at("RotaryPosEmb")); + } + if (time_step) { + fused_mt_xpu_op_desc->SetInput("time_step", name_caches.at("TimeStep")); + } + if (seq_lengths) { + fused_mt_xpu_op_desc->SetInput("seq_lengths", + name_caches.at("SeqLengths")); + } + if (src_mask) { + fused_mt_xpu_op_desc->SetInput("src_mask", name_caches.at("SrcMask")); + } + fused_mt_xpu_op_desc->SetInput("out_linear_bias", + name_caches.at("OutLinearBias")); + fused_mt_xpu_op_desc->SetInput("ffn_ln_scale", + name_caches.at("FFNLnScale")); + fused_mt_xpu_op_desc->SetInput("ffn_ln_bias", name_caches.at("FFNLnBias")); + fused_mt_xpu_op_desc->SetInput("ffn1_bias", name_caches.at("FFN1Bias")); + fused_mt_xpu_op_desc->SetInput("ffn2_bias", name_caches.at("FFN2Bias")); + fused_mt_xpu_op_desc->SetOutput("cache_kv_out", + name_caches.at("CacheKVOut")); + fused_mt_xpu_op_desc->SetOutput("out", name_caches.at("Out")); + + fused_mt_xpu_op_desc->SetInput("qkvw", w_int16_names_vec[0]); + fused_mt_xpu_op_desc->SetInput("qkvw_max", w_max_names_vec[0]); + fused_mt_xpu_op_desc->SetInput("out_linear_w", w_int16_names_vec[1]); + fused_mt_xpu_op_desc->SetInput("out_linear_wmax", w_max_names_vec[1]); + fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_int16_names_vec[2]); + fused_mt_xpu_op_desc->SetInput("ffn1_weight_max", w_max_names_vec[2]); + fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_int16_names_vec[3]); + fused_mt_xpu_op_desc->SetInput("ffn2_weight_max", w_max_names_vec[3]); + if (!fused_mt_xpu_op_desc->HasAttr("rotary_emb_dims")) { + fused_mt_xpu_op_desc->SetAttr("rotary_emb_dims", 0); + } + // unlink QKVW/OutLinearW/FFN1Weight/FFN2Weight from fused_mt_xpu + for (auto nodes : w_nodes_vec) { + for (auto node : nodes) { + IR_NODE_UNLINK(node, fused_mt); + } + } + // link int16 format of QKVW/OutLinearW/FFN1Weight/FFN2Weight to + // fused_mt_xpu + for (auto nodes : 
w_int16_nodes_vec) { + for (auto node : nodes) { + IR_NODE_LINK_TO(node, fused_mt); + } + } + // link QKVWMax/OutLinearWMax/FFN1WeightMax/FFN2WeightMax to fused_mt_xpu + for (auto nodes : w_max_nodes_vec) { + for (auto node : nodes) { + IR_NODE_LINK_TO(node, fused_mt); + } + } + + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fused_multi_transformer_xpu_quant_pass, + paddle::framework::ir::FusedMultiTransformerXPUQuantPass); diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc new file mode 100644 index 00000000000..d3181dbde46 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc @@ -0,0 +1,170 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +#define DEF_INPUT_DATA \ + Layers layers; \ + auto* x = layers.data("x", {1, 128, 1024}); \ + auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \ + auto* ln_scale = layers.data("ln_scale", {1024}, true); \ + auto* ln_bias = layers.data("ln_bias", {1024}, true); \ + auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \ + auto* qkv_bias = layers.data("qkv_bias", {3, 16, 64}, true); \ + auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \ + auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \ + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \ + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \ + auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \ + auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \ + auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \ + auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true); + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, + const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "ln_scale", {1024}); + AddVarToScope(param_scope, "ln_bias", {1024}); + AddVarToScope(param_scope, "ffn_ln_scale", {1024}); + AddVarToScope(param_scope, "ffn_ln_bias", {1024}); + + AddVarToScope(param_scope, "qkv_w", {3, 16, 64, 1024}); + AddVarToScope(param_scope, "out_linear_w", {1024, 1024}); + AddVarToScope(param_scope, "ffn1_w", {1024, 4096}); + AddVarToScope(param_scope, "ffn2_w", {4096, 1024}); + AddVarToScope(param_scope, "qkv_bias", {3072}); + AddVarToScope(param_scope, "out_linear_bias", {1024}); + AddVarToScope(param_scope, "ffn1_bias", {4096}); + AddVarToScope(param_scope, "ffn2_bias", {1024}); + + return 
param_scope; +} + +TEST(FusedMultiTransformerXPUQuantPass, context_stage) { + DEF_INPUT_DATA + + auto* cache_kv = layers.fill_constant_batch_size_like( + x, + static_cast(proto::VarType::FP32), + 0, + 1, + {2, -1, 16, 1024, 64}, + 0); + + layers.fused_multi_transformer(x, + cache_kv, + src_mask, + qkv_w, + qkv_bias, + out_linear_w, + out_linear_bias, + ffn1_w, + ffn1_bias, + ffn2_w, + ffn2_bias, + ln_scale, + ln_bias, + ffn_ln_scale, + ffn_ln_bias, + 0.1, + 1e-12); + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = + PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass"); + if (pass.get() == nullptr) { + LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed"; + } + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer_xpu"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ( + num_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fuse_multi_transformer_layer_pass, " + "The node num in graph should be 1, but the result is %d", + num_nodes_after)); +} + +TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) { + DEF_INPUT_DATA + + auto* cache_kv = layers.fill_constant_batch_size_like( + x, + static_cast(proto::VarType::FP32), + 0, + 1, + {2, -1, 16, 1024, 64}, + 0); + auto* time_step = layers.data("time_step", {1}); + layers.fused_multi_transformer(x, + cache_kv, + src_mask, + qkv_w, + qkv_bias, + out_linear_w, + out_linear_bias, + ffn1_w, + ffn1_bias, + ffn2_w, + ffn2_bias, + ln_scale, + ln_bias, + ffn_ln_scale, + ffn_ln_bias, + 0.1, + 1e-12, + time_step); + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = + PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass"); + if (pass.get() == nullptr) { + LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed"; + } + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer_xpu"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ( + num_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fuse_multi_transformer_layer_pass, " + "The node num in graph should be 1, but the result is %d", + num_nodes_after)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fused_multi_transformer_xpu_quant_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index cb2bf7ad73e..c8cea5ca59b 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "embedding_with_eltwise_add_xpu_fuse_pass", "multi_encoder_xpu_fuse_pass", "multi_encoder_xpu_slice_fuse_pass", + "fused_multi_transformer_xpu_quant_pass", "fc_xpu_fuse_pass", "link_xpu_op_max_pass", "delete_op_device_pass", diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 8fa782e1863..a315d9c8086 100644 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -47,6 +47,16 @@ param : [x, axis, keepdim, reduce_all] backward : frobenius_norm_grad +- op : fused_multi_transformer_xpu + args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] 
out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id) + output : Tensor(out), Tensor[](cache_kv_out){out_linear_w.size()} + infer_meta : + func : FusedMultiTransformerXpuInferMeta + kernel : + func : fused_multi_transformer_xpu + data_type : x + optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask + - op : generate_sequence_xpu args : (Tensor x, DataType dtype) output : Tensor diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 76234856d05..86739d05fb5 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -331,6 +331,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32, phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_multi_transformer_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"unfold", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"unfold_grad", diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index feb95cb32a9..4c4f6746a6c 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -114,4 +114,108 @@ void MultiEncoderXPUInferMeta( } } +void FusedMultiTransformerXpuInferMeta( + const MetaTensor& x, + const std::vector& ln_scale, + const std::vector& ln_bias, + const std::vector& qkvw, + const std::vector& qkvw_max, + const std::vector& qkv_bias, + const std::vector& out_linear_w, + const std::vector& out_linear_wmax, + const std::vector& out_linear_bias, + const std::vector& ffn_ln_scale, + const std::vector& ffn_ln_bias, + const std::vector& ffn1_weight, + const std::vector& ffn1_weight_max, + const std::vector& ffn1_bias, + const std::vector& ffn2_weight, + const std::vector& ffn2_weight_max, + const std::vector& ffn2_bias, + const std::vector& cache_kv, + const std::vector& pre_caches, + const std::vector& rotary_pos_emb, + const std::vector& time_step, + const std::vector& seq_lengths, + const std::vector& src_mask, + bool pre_layer_norm, + int rotary_emb_dims, + float epsilon, + float dropout_rate, + bool is_test, + const std::string& dropout_implementation, + const std::string& act_method, + bool trans_qkvw, + int ring_id, + MetaTensor* out, + std::vector cache_kv_out) { + auto x_dim = x.dims(); + auto y_dim = qkvw[0]->dims(); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + phi::errors::InvalidArgument("The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ( + y_dim.size(), + 4, + phi::errors::InvalidArgument("The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ( + x_dim[2], + trans_qkvw ? y_dim[3] : y_dim[0], + phi::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " + "true) or y_dim[0](trans_qkvw is false)" + "must be equal. 
But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, + y_dim)); + if (cache_kv.size() > 0) { + const auto& c_dim = cache_kv[0]->dims(); + PADDLE_ENFORCE_EQ( + c_dim.size(), + 5, + phi::errors::InvalidArgument("The CacheKV must be 5 dims, but got %d", + c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], + 2, + phi::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + PADDLE_ENFORCE_EQ(c_dim[1], + x_dim[0], + phi::errors::InvalidArgument( + "The second dim of CacheKV must be equal with " + "batch size %d, but got %d", + x_dim[0], + c_dim[1])); // batch_size + PADDLE_ENFORCE_EQ(c_dim[2], + trans_qkvw ? y_dim[1] : y_dim[2], + phi::errors::InvalidArgument( + "The third dim of CacheKV must be equal with num " + "head %d, but got %d", + trans_qkvw ? y_dim[1] : y_dim[2], + c_dim[2])); // num_head + PADDLE_ENFORCE_EQ(c_dim[4], + trans_qkvw ? y_dim[2] : y_dim[3], + phi::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal with head " + "size %d, but got %d", + trans_qkvw ? y_dim[2] : y_dim[3], + c_dim[4])); // head_size + } + + out->set_dims(x_dim); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 6cba0552b1a..a08b6450bcf 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -66,4 +66,39 @@ void MultiEncoderXPUInferMeta( MetaTensor* x_fp16, MetaTensor* out_fp16); +void FusedMultiTransformerXpuInferMeta( + const MetaTensor& x, + const std::vector& ln_scale, + const std::vector& ln_bias, + const std::vector& qkvw, + const std::vector& qkvw_max, + const std::vector& qkv_bias, + const std::vector& out_linear_w, + const std::vector& out_linear_wmax, + const std::vector& out_linear_bias, + const std::vector& ffn_ln_scale, + const std::vector& ffn_ln_bias, + const std::vector& ffn1_weight, + const std::vector& ffn1_weight_max, + const std::vector& ffn1_bias, + const std::vector& ffn2_weight, + const std::vector& ffn2_weight_max, + const std::vector& ffn2_bias, + const std::vector& cache_kv, + const std::vector& pre_caches, + const std::vector& rotary_pos_emb, + const std::vector& time_step, + const std::vector& seq_lengths, + const std::vector& src_mask, + bool pre_layer_norm, + int rotary_emb_dims, + float epsilon, + float dropout_rate, + bool is_test, + const std::string& dropout_implementation, + const std::string& act_method, + bool trans_qkvw, + int ring_id, + MetaTensor* out, + std::vector cache_kv_out); } // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc new file mode 100644 index 00000000000..5f49c24ea71 --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc @@ -0,0 +1,275 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/memcpy_kernel.h" +#ifdef PADDLE_WITH_XPU_XFT +#include "models/fused_multi_transformer_op.h" +namespace xft = baidu::xpu::xft; +#endif + +namespace phi { +namespace fusion { + +template +void FusedMultiTransformerXpuKernel( + const Context& ctx, + const DenseTensor& xx, + const std::vector& ln_scale, + const std::vector& ln_bias, + const std::vector& qkvw, + const std::vector& qkvw_max, + const std::vector& qkv_bias, + const std::vector& out_linear_w, + const std::vector& out_linear_wmax, + const std::vector& out_linear_bias, + const std::vector& ffn_ln_scale, + const std::vector& ffn_ln_bias, + const std::vector& ffn1_weight, + const std::vector& ffn1_weight_max, + const std::vector& ffn1_bias, + const std::vector& ffn2_weight, + const std::vector& ffn2_weight_max, + const std::vector& ffn2_bias, + const paddle::optional>& cache_kv, + const paddle::optional>& pre_caches, + const paddle::optional& rotary_pos_emb, + const paddle::optional& time_step, + const paddle::optional& seq_lengths, + const paddle::optional& src_mask, + bool pre_layer_norm, + int rotary_emb_dims, + float epsilon, + float dropout_rate, + bool is_test, + const std::string& dropout_implementation, + const std::string& act_method, + bool trans_qkvw, + int ring_id, + DenseTensor* out, + std::vector cache_kv_out) { +#ifdef PADDLE_WITH_XPU_XFT + using XPUTypeT = typename XPUTypeTrait::Type; + + PADDLE_ENFORCE_EQ(pre_layer_norm, + true, + phi::errors::PreconditionNotMet( + "Only support pre_layer_norm = true at now.")); + PADDLE_ENFORCE_EQ( + seq_lengths.get_ptr(), + nullptr, + phi::errors::PreconditionNotMet("seq_lengths not support at now.")); + PADDLE_ENFORCE_EQ( + rotary_pos_emb.get_ptr(), + nullptr, + phi::errors::PreconditionNotMet("rotary_pos_emb not support at now.")); + PADDLE_ENFORCE_EQ( + pre_caches.get_ptr(), + nullptr, + phi::errors::PreconditionNotMet("pre_caches not support at now.")); + PADDLE_ENFORCE_NE( + src_mask.get_ptr(), + nullptr, + phi::errors::PreconditionNotMet("src_mask should not be nullptr.")); + PADDLE_ENFORCE_EQ(trans_qkvw, + true, + phi::errors::PreconditionNotMet( + "Only support trans_qkvw == true at now.")); + + const auto x_dims = xx.dims(); + int seq_len = x_dims[1]; + const auto qkv_w_dims = qkvw[0]->dims(); + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? 
qkv_w_dims[2] : qkv_w_dims[3]; + + int time_step_value = -1; + if (time_step) { + PADDLE_ENFORCE_EQ(time_step.get_ptr()->place(), + phi::CPUPlace(), + phi::errors::PreconditionNotMet( + "The place of input(time_step) must be CPUPlace.")); + // cache_seq_len + time_step_value = time_step.get_ptr()->data()[0]; + PADDLE_ENFORCE_GT( + time_step_value, + 0, + phi::errors::PreconditionNotMet( + "The value of time_step must > 0, but now is %d", time_step_value)); + PADDLE_ENFORCE_EQ( + seq_len, + 1, + phi::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + } + + XPUTypeT* x_data = reinterpret_cast(const_cast(xx.data())); + XPUTypeT* src_mask_data = reinterpret_cast( + const_cast(src_mask.get_ptr()->data())); + auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + auto src_mask_dims = src_mask.get_ptr()->dims(); + auto out_dims = out->dims(); + auto xft_x = xft::xftTensor( + x_data, std::array{x_dims[0], x_dims[1], x_dims[2]}); + // TODO(mayang02): xft support mask.dtype = float16 + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + float* src_mask_fp32_data = + RAII_GUARD.alloc(src_mask.get_ptr()->numel()); + int r = xpu::cast(ctx.x_context(), + src_mask_data, + src_mask_fp32_data, + src_mask.get_ptr()->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::cast"); + auto xft_src_mask = + xft::xftTensor(src_mask_fp32_data, + std::array{src_mask_dims[0], + src_mask_dims[1], + src_mask_dims[2], + src_mask_dims[3]}); + auto xft_out = xft::xftTensor( + out_data, std::array{out_dims[0], out_dims[1], out_dims[2]}); + + typedef int16_t TW; + std::vector> xft_ln_scale; + std::vector> xft_ln_bias; + std::vector> xft_qkvw; + std::vector> xft_qkv_bias; + std::vector> xft_out_linear_w; + std::vector> xft_out_linear_bias; + std::vector> xft_ffn_ln_scale; + std::vector> xft_ffn_ln_bias; + std::vector> xft_ffn1_w; + std::vector> xft_ffn1_bias; + std::vector> xft_ffn2_w; + std::vector> xft_ffn2_bias; + std::vector> xft_cache_kv; + std::vector> xft_cache_kv_out; + + int layers = qkvw.size(); + for (int i = 0; i < layers; ++i) { + // step1. layer_norm + xft_ln_scale.emplace_back(const_cast(ln_scale[i]->data()), + std::array{ln_scale[i]->dims()[0]}); + xft_ln_bias.emplace_back(const_cast(ln_bias[i]->data()), + std::array{ln_bias[i]->dims()[0]}); + // step2. 
qkv + auto qkvw_dims = qkvw[i]->dims(); + xft_qkvw.emplace_back( + const_cast(qkvw[i]->data()), + const_cast(qkvw_max[i]->data()), + std::array{qkvw_dims[0] * qkvw_dims[1] * qkvw_dims[2], + qkvw_dims[3]}); + auto qkvb_dims = qkv_bias[i]->dims(); + xft_qkv_bias.emplace_back( + const_cast(qkv_bias[i]->data()), + std::array{qkvb_dims[0] * qkvb_dims[1] * qkvb_dims[2]}); + // attn out + auto outw_dims = out_linear_w[i]->dims(); + xft_out_linear_w.emplace_back( + const_cast(out_linear_w[i]->data()), + const_cast(out_linear_wmax[i]->data()), + std::array{outw_dims[0], outw_dims[1]}); + xft_out_linear_bias.emplace_back( + const_cast(out_linear_bias[i]->data()), + std::array{out_linear_bias[i]->dims()[0]}); + // ffn ln + xft_ffn_ln_scale.emplace_back( + const_cast(ffn_ln_scale[i]->data()), + std::array{ffn_ln_scale[i]->dims()[0]}); + xft_ffn_ln_bias.emplace_back( + const_cast(ffn_ln_bias[i]->data()), + std::array{ffn_ln_bias[i]->dims()[0]}); + // ffn1 + auto ffn1w_dims = ffn1_weight[i]->dims(); + xft_ffn1_w.emplace_back( + const_cast(ffn1_weight[i]->data()), + const_cast(ffn1_weight_max[i]->data()), + std::array{ffn1w_dims[0], ffn1w_dims[1]}); + xft_ffn1_bias.emplace_back(const_cast(ffn1_bias[i]->data()), + std::array{ffn1_bias[i]->dims()[0]}); + // ffn2 + auto ffn2w_dims = ffn2_weight[i]->dims(); + xft_ffn2_w.emplace_back( + const_cast(ffn2_weight[i]->data()), + const_cast(ffn2_weight_max[i]->data()), + std::array{ffn2w_dims[0], ffn2w_dims[1]}); + xft_ffn2_bias.emplace_back(const_cast(ffn2_bias[i]->data()), + std::array{ffn2_bias[i]->dims()[0]}); + // cache kv in + if (time_step_value > 0) { + auto cachekv_dims = cache_kv.get_ptr()->at(i)->dims(); + xft_cache_kv.emplace_back(reinterpret_cast(const_cast( + cache_kv.get_ptr()->at(i)->data())), + std::array{cachekv_dims[0], + cachekv_dims[1], + cachekv_dims[2], + cachekv_dims[3], + cachekv_dims[4]}); + } + // cache kv out + auto cachekv_out_dims = cache_kv_out[i]->dims(); + xft_cache_kv_out.emplace_back( + reinterpret_cast(ctx.template Alloc(cache_kv_out[i])), + std::array{cachekv_out_dims[0], + cachekv_out_dims[1], + cachekv_out_dims[2], + cachekv_out_dims[3], + cachekv_out_dims[4]}); + } + + xft::NlpParam param; + param.num_layer = layers; + param.n_head = num_head; + param.size_per_head = dim_head; + param.hidden_act = act_method; + param.is_fuse_qkv = true; + r = xft::fused_multi_transformer(ctx.x_context(), + xft_x, + xft_cache_kv, + xft_src_mask, + xft_ln_scale, + xft_ln_bias, + xft_qkvw, + xft_qkv_bias, + xft_out_linear_w, + xft_out_linear_bias, + xft_ffn_ln_scale, + xft_ffn_ln_bias, + xft_ffn1_w, + xft_ffn1_bias, + xft_ffn2_w, + xft_ffn2_bias, + param, + time_step_value, + &xft_out, + xft_cache_kv_out); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xft::fused_multi_transformer"); +#else + LOG(FATAL) << "fused_multi_transformer_xpu is not supported since it's not " + "compiled with XPU_XFT"; +#endif +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_multi_transformer_xpu, + XPU, + ALL_LAYOUT, + phi::fusion::FusedMultiTransformerXpuKernel, + float, + phi::dtype::float16) { + kernel->InputAt(20).SetBackend(phi::Backend::CPU); +} -- GitLab
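
Note (appended after the patch, not part of it): with this change the quant pass runs automatically in XPU inference builds, since it is appended to XpuPassStrategy in paddle_pass_builder.cc above; it can also be applied by hand through the PassRegistry, exactly as the new fused_multi_transformer_xpu_quant_pass_tester.cc does. Below is a minimal sketch of that manual flow. It assumes the caller has already built a ProgramDesc containing fused_multi_transformer ops and a Scope holding their fp32 weights; the helper name ApplyFusedMultiTransformerXpuQuantPass is illustrative only and does not appear in this patch.

    // Standalone sketch of applying fused_multi_transformer_xpu_quant_pass to
    // a graph, mirroring fused_multi_transformer_xpu_quant_pass_tester.cc.
    // The wrapper function name is illustrative and not part of this patch.
    #include <memory>

    #include "paddle/fluid/framework/ir/pass.h"
    #include "paddle/fluid/framework/ir/pass_tester_helper.h"
    #include "paddle/fluid/framework/scope.h"

    namespace paddle {
    namespace framework {
    namespace ir {

    // `program` is expected to contain fused_multi_transformer ops (e.g. built
    // with the Layers helper); `param_scope` owns the fp32 weight tensors and
    // is handed over to the graph, which takes ownership of it.
    int ApplyFusedMultiTransformerXpuQuantPass(const ProgramDesc& program,
                                               Scope* param_scope) {
      std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
      // The pass looks up weights via the scope attached as "__param_scope__".
      graph->Set("__param_scope__", param_scope);

      auto pass =
          PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
      graph.reset(pass->Apply(graph.release()));

      // Each matched subgraph is rewritten in place into a
      // fused_multi_transformer_xpu op with int16 weights and weight maxes.
      return GetNumOpNodes(graph, "fused_multi_transformer_xpu");
    }

    }  // namespace ir
    }  // namespace framework
    }  // namespace paddle

    // Pull the pass into the link, as the bundled tester does.
    USE_PASS(fused_multi_transformer_xpu_quant_pass);

As with the cc_test added to paddle/fluid/framework/ir/CMakeLists.txt, such a translation unit would have to link against the fused_multi_transformer_xpu_quant_pass target for USE_PASS to resolve.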