Unverified commit 52e1742f, authored by mayang002, committed by GitHub

[xpu] fused_multi_transformer_xpu pass&kernel support (#51571)

Parent c36e3fd2
......@@ -142,6 +142,8 @@ if(WITH_XPU_XFT)
message(STATUS "Compile with XPU XFT!")
add_definitions(-DPADDLE_WITH_XPU_XFT)
set(XPU_XFT_INC_DIR "${XPU_INC_DIR}/xft")
include_directories(${XPU_XFT_INC_DIR})
set(XPU_XFT_LIB "${XPU_LIB_DIR}/${XPU_XFT_LIB_NAME}")
endif()
......
......@@ -235,6 +235,8 @@ if(WITH_XPU)
pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_isolated_node_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -493,4 +495,8 @@ if(WITH_XPU)
test_delete_isolated_node_pass
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass)
cc_test(
test_fused_multi_transformer_xpu_quant_pass
SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
DEPS fused_multi_transformer_xpu_quant_pass)
endif()
......@@ -75,7 +75,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
1,
{2, -1, 16, 1024, 64},
0);
auto* out = layers.fused_multi_transformer(x,
auto outs = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
......@@ -93,7 +93,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
0.1,
1e-12);
x = out;
x = outs[0];
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
......@@ -126,7 +126,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
for (int i = 0; i < num_layers; ++i) {
auto* shape_out = layers.shape(src_mask);
auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4});
auto* out = layers.fused_multi_transformer(x,
auto outs = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
......@@ -145,7 +145,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
1e-12,
time_stamp);
x = out;
x = outs[0];
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto param_scope = CreateParamScope();
......
......@@ -151,6 +151,15 @@ class Node {
var_desc_->SetName(new_name);
}
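// Rename an operator node in place, keeping its OpDesc type in sync.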
void RenameOp(const std::string& new_name) {
PADDLE_ENFORCE_EQ(
type_ == Type::kOperation && op_desc_,
true,
platform::errors::InvalidArgument("Node must be type of variable."));
name_ = new_name;
op_desc_->SetType(new_name);
}
int DescOrder() const { return desc_order_; }
int GetVarNodeBlockId() const {
......
......@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
"delete_op_device_pass"};
......
......@@ -571,33 +571,35 @@ struct Layers {
return out;
}
VarDesc* fused_multi_transformer(VarDesc* x,
VarDesc* cache_kv,
VarDesc* src_mask,
VarDesc* qkv_w,
VarDesc* qkv_bias,
VarDesc* out_linear_w,
VarDesc* out_linear_bias,
VarDesc* ffn1_w,
VarDesc* ffn1_bias,
VarDesc* ffn2_w,
VarDesc* ffn2_bias,
VarDesc* ln_scale,
VarDesc* ln_bias,
VarDesc* ffn_ln_scale,
VarDesc* ffn_ln_bias,
float epsilon,
float dropout_rate,
VarDesc* time_stamp = nullptr,
VarDesc* qkv_out_scale = nullptr,
VarDesc* out_linear_out_scale = nullptr,
VarDesc* ffn1_out_scale = nullptr,
VarDesc* ffn2_out_scale = nullptr,
std::vector<float> qkv_in_scale = {},
std::vector<float> out_linear_in_scale = {},
std::vector<float> ffn1_in_scale = {},
std::vector<float> ffn2_in_scale = {}) {
std::vector<VarDesc*> fused_multi_transformer(
VarDesc* x,
VarDesc* cache_kv,
VarDesc* src_mask,
VarDesc* qkv_w,
VarDesc* qkv_bias,
VarDesc* out_linear_w,
VarDesc* out_linear_bias,
VarDesc* ffn1_w,
VarDesc* ffn1_bias,
VarDesc* ffn2_w,
VarDesc* ffn2_bias,
VarDesc* ln_scale,
VarDesc* ln_bias,
VarDesc* ffn_ln_scale,
VarDesc* ffn_ln_bias,
float epsilon,
float dropout_rate,
VarDesc* time_stamp = nullptr,
VarDesc* qkv_out_scale = nullptr,
VarDesc* out_linear_out_scale = nullptr,
VarDesc* ffn1_out_scale = nullptr,
VarDesc* ffn2_out_scale = nullptr,
std::vector<float> qkv_in_scale = {},
std::vector<float> out_linear_in_scale = {},
std::vector<float> ffn1_in_scale = {},
std::vector<float> ffn2_in_scale = {}) {
VarDesc* out = lod_tensor(unique_name());
VarDesc* cache_kv_out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
std::string op_type = qkv_out_scale ? "fused_multi_transformer_int8"
: "fused_multi_transformer";
......@@ -623,6 +625,7 @@ struct Layers {
op->SetAttr("dropout_rate", dropout_rate);
op->SetAttr("epsilon", epsilon);
op->SetOutput("Out", {out->Name()});
op->SetOutput("CacheKVOut", {cache_kv_out->Name()});
if (time_stamp) {
op->SetInput("TimeStep", {time_stamp->Name()});
......@@ -638,7 +641,8 @@ struct Layers {
op->SetAttr("ffn1_in_scale", ffn1_in_scale);
op->SetAttr("ffn2_in_scale", ffn2_in_scale);
}
return out;
std::vector<VarDesc*> outs = {out, cache_kv_out};
return outs;
}
VarDesc* dequantize_linear(VarDesc* x,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerPattern : public PatternBase {
FusedMultiTransformerPattern(PDPattern* pattern,
const std::string& name_scope,
bool with_cache_kv,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
bool with_seq_lengths,
bool with_src_mask);
// declare operator node's name
PATTERN_DECL_NODE(fused_mt);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(ln_scale);
PATTERN_DECL_NODE(ln_bias);
PATTERN_DECL_NODE(qkv_w);
PATTERN_DECL_NODE(qkv_bias);
PATTERN_DECL_NODE(cache_kv);
PATTERN_DECL_NODE(pre_caches);
PATTERN_DECL_NODE(rotary_pos_emb);
PATTERN_DECL_NODE(time_step);
PATTERN_DECL_NODE(seq_lengths);
PATTERN_DECL_NODE(src_mask);
PATTERN_DECL_NODE(out_linear_w);
PATTERN_DECL_NODE(out_linear_bias);
PATTERN_DECL_NODE(ffn_ln_scale);
PATTERN_DECL_NODE(ffn_ln_bias);
PATTERN_DECL_NODE(ffn1_w);
PATTERN_DECL_NODE(ffn1_bias);
PATTERN_DECL_NODE(ffn2_w);
PATTERN_DECL_NODE(ffn2_bias);
PATTERN_DECL_NODE(cache_kv_out);
PATTERN_DECL_NODE(out);
private:
bool with_cache_kv_{false};
bool with_pre_caches_{false};
bool with_rotary_pos_emb_{false};
bool with_time_step_{false};
bool with_seq_lengths_{false};
bool with_src_mask_{false};
};
FusedMultiTransformerPattern::FusedMultiTransformerPattern(
PDPattern* pattern,
const std::string& name_scope,
bool with_cache_kv,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
bool with_seq_lengths,
bool with_src_mask)
: PatternBase(pattern, name_scope, name_scope),
with_cache_kv_(with_cache_kv),
with_pre_caches_(with_pre_caches),
with_rotary_pos_emb_(with_rotary_pos_emb),
with_time_step_(with_time_step),
with_seq_lengths_(with_seq_lengths),
with_src_mask_(with_src_mask) {
std::string op_type = "fused_multi_transformer";
auto* fused_mt = pattern->NewNode(fused_mt_repr())->assert_is_op(op_type);
// inputs and outputs
auto* x = pattern->NewNode(x_repr())
->assert_is_op_input(op_type, "X")
->assert_var_not_persistable();
auto* cache_kv_out = pattern->NewNode(cache_kv_out_repr())
->assert_is_op_output(op_type, "CacheKVOut")
->assert_var_not_persistable();
auto* out = pattern->NewNode(out_repr())
->assert_is_op_output(op_type, "Out")
->assert_var_not_persistable();
// weights and biases
auto* ln_scale = pattern->NewNode(ln_scale_repr())
->assert_is_op_input(op_type, "LnScale")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
});
auto* ln_bias = pattern->NewNode(ln_bias_repr())
->assert_is_op_input(op_type, "LnBias")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
});
auto* qkv_w = pattern->NewNode(qkv_w_repr())
->assert_is_op_input(op_type, "QKVW")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 4;
});
auto* qkv_bias = pattern->NewNode(qkv_bias_repr())
->assert_is_op_input(op_type, "QKVBias")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 3;
});
auto* out_linear_w = pattern->NewNode(out_linear_w_repr())
->assert_is_op_input(op_type, "OutLinearW")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 2;
});
auto* out_linear_bias = pattern->NewNode(out_linear_bias_repr())
->assert_is_op_input(op_type, "OutLinearBias")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
});
auto* ffn_ln_scale = pattern->NewNode(ffn_ln_scale_repr())
->assert_is_op_input(op_type, "FFNLnScale")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
});
auto* ffn_ln_bias = pattern->NewNode(ffn_ln_bias_repr())
->assert_is_op_input(op_type, "FFNLnBias")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
});
auto* ffn1_w = pattern->NewNode(ffn1_w_repr())
->assert_is_op_input(op_type, "FFN1Weight")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 2;
});
auto* ffn1_bias = pattern->NewNode(ffn1_bias_repr())
->assert_is_op_input(op_type, "FFN1Bias")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
});
auto* ffn2_w = pattern->NewNode(ffn2_w_repr())
->assert_is_op_input(op_type, "FFN2Weight")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 2;
});
auto* ffn2_bias = pattern->NewNode(ffn2_bias_repr())
->assert_is_op_input(op_type, "FFN2Bias")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
});
std::vector<PDNode*> input_vars{x,
ln_scale,
ln_bias,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn_ln_scale,
ffn_ln_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias};
std::vector<PDNode*> output_vars{cache_kv_out, out};
// optional node
PDNode* cache_kv = nullptr;
PDNode* pre_caches = nullptr;
PDNode* rotary_pos_emb = nullptr;
PDNode* time_step = nullptr;
PDNode* seq_lengths = nullptr;
PDNode* src_mask = nullptr;
if (with_cache_kv_) {
cache_kv = pattern->NewNode(cache_kv_repr())
->assert_is_op_input(op_type, "CacheKV")
->assert_var_not_persistable();
input_vars.push_back(cache_kv);
}
if (with_pre_caches_) {
pre_caches = pattern->NewNode(pre_caches_repr())
->assert_is_op_input(op_type, "PreCaches")
->assert_var_not_persistable();
input_vars.push_back(pre_caches);
}
if (with_rotary_pos_emb_) {
rotary_pos_emb = pattern->NewNode(rotary_pos_emb_repr())
->assert_is_op_input(op_type, "RotaryPosEmb")
->assert_var_not_persistable();
input_vars.push_back(rotary_pos_emb);
}
if (with_time_step_) {
time_step = pattern->NewNode(time_step_repr())
->assert_is_op_input(op_type, "TimeStep")
->assert_var_not_persistable();
input_vars.push_back(time_step);
}
if (with_seq_lengths_) {
seq_lengths = pattern->NewNode(seq_lengths_repr())
->assert_is_op_input(op_type, "SeqLengths")
->assert_var_not_persistable();
input_vars.push_back(seq_lengths);
}
if (with_src_mask_) {
src_mask = pattern->NewNode(src_mask_repr())
->assert_is_op_input(op_type, "SrcMask")
->assert_var_not_persistable();
input_vars.push_back(src_mask);
}
fused_mt->LinksFrom(input_vars).LinksTo(output_vars);
}
} // namespace patterns
/*
1. Transpose and quantize the weights of the fused_multi_transformer op from
fp32 to int16.
*/
class FusedMultiTransformerXPUQuantPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplyImpl(ir::Graph* graph,
bool with_cache_kv,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
bool with_seq_lengths,
bool with_src_mask) const;
const std::string name_scope_{"fused_multi_transformer_xpu_quant_pass"};
};
void FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
VLOG(3) << "in FusedMultiTransformerXPUQuantPass::ApplyImpl";
int found_subgraph_count = 0;
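// The pattern is matched twice: with TimeStep present (decode stage) and
// without it (context/encoder stage); all other optional inputs are fixed.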
for (bool with_time_step : {true, false}) {
found_subgraph_count +=
ApplyImpl(graph, true, false, false, with_time_step, false, true);
}
AddStatis(found_subgraph_count);
}
int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
bool with_cache_kv,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
bool with_seq_lengths,
bool with_src_mask) const {
GraphPatternDetector gpd;
patterns::FusedMultiTransformerPattern pattern(gpd.mutable_pattern(),
name_scope_,
with_cache_kv,
with_pre_caches,
with_rotary_pos_emb,
with_time_step,
with_seq_lengths,
with_src_mask);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FusedMultiTransformerXPUQuantPass fuse";
GET_IR_NODE(x);
GET_IR_NODE(ln_scale);
GET_IR_NODE(ln_bias);
GET_IR_NODE(qkv_w);
GET_IR_NODE(qkv_bias);
GET_IR_NODE(cache_kv);
GET_IR_NODE(pre_caches);
GET_IR_NODE(rotary_pos_emb);
GET_IR_NODE(time_step);
GET_IR_NODE(seq_lengths);
GET_IR_NODE(src_mask);
GET_IR_NODE(out_linear_w);
GET_IR_NODE(out_linear_bias);
GET_IR_NODE(ffn_ln_scale);
GET_IR_NODE(ffn_ln_bias);
GET_IR_NODE(ffn1_w);
GET_IR_NODE(ffn1_bias);
GET_IR_NODE(ffn2_w);
GET_IR_NODE(ffn2_bias);
GET_IR_NODE(cache_kv_out);
GET_IR_NODE(out);
GET_IR_NODE(fused_mt);
auto* block = fused_mt->Op()->Block();
auto* scope = param_scope();
// quant weight nodes
// w_nodes_vec: [QKVW, OutLinearW, FFN1Weight, FFN2Weight]
std::vector<std::vector<Node*>> w_nodes_vec(4);
std::vector<std::vector<Node*>> w_int16_nodes_vec(4);
std::vector<std::vector<Node*>> w_max_nodes_vec(4);
std::vector<std::vector<std::string>> w_int16_names_vec(4);
std::vector<std::vector<std::string>> w_max_names_vec(4);
auto quant_func = [&](const std::string& input_name,
std::vector<Node*>* w_nodes,
std::vector<Node*>* w_int16_nodes,
std::vector<Node*>* w_max_nodes,
std::vector<std::string>* w_int16_names,
std::vector<std::string>* w_max_names,
bool need_transpose) {
typedef int16_t TW;
auto w_names = fused_mt->Op()->Input(input_name);
for (auto w_name : w_names) {
Node* w_node = FindNodeWithName(graph, w_name);
Node* w_int16 = nullptr;
Node* w_max = nullptr;
PADDLE_ENFORCE_NE(
w_node,
nullptr,
platform::errors::Fatal("w node should not be nullptr"));
PrepareWeight<TW>(
graph, scope, block, w_node, &w_int16, &w_max, need_transpose);
w_nodes->push_back(w_node);
w_int16_nodes->push_back(w_int16);
w_max_nodes->push_back(w_max);
}
for (size_t i = 0; i < w_names.size(); ++i) {
w_int16_names->push_back(w_int16_nodes->at(i)->Name());
w_max_names->push_back(w_max_nodes->at(i)->Name());
}
PADDLE_ENFORCE_EQ(
w_names.size(),
w_nodes->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_nodes(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_int16_nodes->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_int16_nodes(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_int16_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_max_nodes->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_max_nodes(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_max_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_int16_names->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_int16_names(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_int16_names->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_max_names->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_max_names(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_max_names->size())));
};
quant_func("QKVW",
&(w_nodes_vec[0]),
&(w_int16_nodes_vec[0]),
&(w_max_nodes_vec[0]),
&(w_int16_names_vec[0]),
&(w_max_names_vec[0]),
false);
quant_func("OutLinearW",
&(w_nodes_vec[1]),
&(w_int16_nodes_vec[1]),
&(w_max_nodes_vec[1]),
&(w_int16_names_vec[1]),
&(w_max_names_vec[1]),
true);
quant_func("FFN1Weight",
&(w_nodes_vec[2]),
&(w_int16_nodes_vec[2]),
&(w_max_nodes_vec[2]),
&(w_int16_names_vec[2]),
&(w_max_names_vec[2]),
true);
quant_func("FFN2Weight",
&(w_nodes_vec[3]),
&(w_int16_nodes_vec[3]),
&(w_max_nodes_vec[3]),
&(w_int16_names_vec[3]),
&(w_max_names_vec[3]),
true);
// cast some nodes to fp32 nodes
std::vector<Node*> fp32_nodes;
auto cast_tofp32_func = [&](const std::string& input_name) {
auto names = fused_mt->Op()->Input(input_name);
for (auto name : names) {
auto* curr_tensor = scope->Var(name)->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_NE(
curr_tensor,
nullptr,
platform::errors::Fatal("tensor node should not be nullptr"));
CastToFp32(curr_tensor);
Node* curr_node = FindNodeWithName(graph, name);
fp32_nodes.push_back(curr_node);
}
};
cast_tofp32_func("LnScale");
cast_tofp32_func("LnBias");
cast_tofp32_func("QKVBias");
cast_tofp32_func("OutLinearBias");
cast_tofp32_func("FFNLnScale");
cast_tofp32_func("FFNLnBias");
cast_tofp32_func("FFN1Bias");
cast_tofp32_func("FFN2Bias");
// Generate fused_multi_transformer_xpu op inplace
fused_mt->RenameOp("fused_multi_transformer_xpu");
framework::OpDesc* fused_mt_xpu_op_desc = fused_mt->Op();
fused_mt_xpu_op_desc->SetType("fused_multi_transformer_xpu");
std::unordered_map<std::string, std::vector<std::string>> name_caches;
for (auto key : fused_mt_xpu_op_desc->InputNames()) {
name_caches.insert({key, fused_mt_xpu_op_desc->Input(key)});
}
for (auto key : fused_mt_xpu_op_desc->OutputNames()) {
name_caches.insert({key, fused_mt_xpu_op_desc->Output(key)});
}
fused_mt_xpu_op_desc->MutableInputs()->clear();
fused_mt_xpu_op_desc->MutableOutputs()->clear();
fused_mt_xpu_op_desc->SetInput("x", name_caches.at("X"));
fused_mt_xpu_op_desc->SetInput("ln_scale", name_caches.at("LnScale"));
fused_mt_xpu_op_desc->SetInput("ln_bias", name_caches.at("LnBias"));
fused_mt_xpu_op_desc->SetInput("qkv_bias", name_caches.at("QKVBias"));
if (cache_kv) {
fused_mt_xpu_op_desc->SetInput("cache_kv", name_caches.at("CacheKV"));
}
if (pre_caches) {
fused_mt_xpu_op_desc->SetInput("pre_caches", name_caches.at("PreCaches"));
}
if (rotary_pos_emb) {
fused_mt_xpu_op_desc->SetInput("rotary_pos_emb",
name_caches.at("RotaryPosEmb"));
}
if (time_step) {
fused_mt_xpu_op_desc->SetInput("time_step", name_caches.at("TimeStep"));
}
if (seq_lengths) {
fused_mt_xpu_op_desc->SetInput("seq_lengths",
name_caches.at("SeqLengths"));
}
if (src_mask) {
fused_mt_xpu_op_desc->SetInput("src_mask", name_caches.at("SrcMask"));
}
fused_mt_xpu_op_desc->SetInput("out_linear_bias",
name_caches.at("OutLinearBias"));
fused_mt_xpu_op_desc->SetInput("ffn_ln_scale",
name_caches.at("FFNLnScale"));
fused_mt_xpu_op_desc->SetInput("ffn_ln_bias", name_caches.at("FFNLnBias"));
fused_mt_xpu_op_desc->SetInput("ffn1_bias", name_caches.at("FFN1Bias"));
fused_mt_xpu_op_desc->SetInput("ffn2_bias", name_caches.at("FFN2Bias"));
fused_mt_xpu_op_desc->SetOutput("cache_kv_out",
name_caches.at("CacheKVOut"));
fused_mt_xpu_op_desc->SetOutput("out", name_caches.at("Out"));
fused_mt_xpu_op_desc->SetInput("qkvw", w_int16_names_vec[0]);
fused_mt_xpu_op_desc->SetInput("qkvw_max", w_max_names_vec[0]);
fused_mt_xpu_op_desc->SetInput("out_linear_w", w_int16_names_vec[1]);
fused_mt_xpu_op_desc->SetInput("out_linear_wmax", w_max_names_vec[1]);
fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_int16_names_vec[2]);
fused_mt_xpu_op_desc->SetInput("ffn1_weight_max", w_max_names_vec[2]);
fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_int16_names_vec[3]);
fused_mt_xpu_op_desc->SetInput("ffn2_weight_max", w_max_names_vec[3]);
if (!fused_mt_xpu_op_desc->HasAttr("rotary_emb_dims")) {
fused_mt_xpu_op_desc->SetAttr("rotary_emb_dims", 0);
}
// unlink QKVW/OutLinearW/FFN1Weight/FFN2Weight from fused_mt_xpu
for (auto nodes : w_nodes_vec) {
for (auto node : nodes) {
IR_NODE_UNLINK(node, fused_mt);
}
}
// link int16 format of QKVW/OutLinearW/FFN1Weight/FFN2Weight to
// fused_mt_xpu
for (auto nodes : w_int16_nodes_vec) {
for (auto node : nodes) {
IR_NODE_LINK_TO(node, fused_mt);
}
}
// link QKVWMax/OutLinearWMax/FFN1WeightMax/FFN2WeightMax to fused_mt_xpu
for (auto nodes : w_max_nodes_vec) {
for (auto node : nodes) {
IR_NODE_LINK_TO(node, fused_mt);
}
}
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_multi_transformer_xpu_quant_pass,
paddle::framework::ir::FusedMultiTransformerXPUQuantPass);
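The comment at the top of this pass says the fused_multi_transformer weights are transposed and quantized from fp32 to int16; the actual conversion happens inside PrepareWeight<int16_t> (declared in xpu/pass_utils.h), which is not part of this diff. As a rough orientation only, here is a standalone sketch of per-tensor abs-max int16 quantization of that kind; the helper name and signature below are illustrative, not the pass's API.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: quantize a flat fp32 weight buffer to int16 with a
// per-tensor abs-max scale, and report that max so the kernel can dequantize.
static std::vector<int16_t> QuantizeToInt16AbsMax(const std::vector<float>& w,
                                                  float* max_abs_out) {
  float max_abs = 0.f;
  for (float v : w) max_abs = std::max(max_abs, std::fabs(v));
  *max_abs_out = max_abs;
  const float scale = (max_abs > 0.f) ? 32767.f / max_abs : 0.f;
  std::vector<int16_t> q(w.size());
  for (std::size_t i = 0; i < w.size(); ++i) {
    q[i] = static_cast<int16_t>(std::round(w[i] * scale));
  }
  return q;
}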
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#define DEF_INPUT_DATA \
Layers layers; \
auto* x = layers.data("x", {1, 128, 1024}); \
auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \
auto* ln_scale = layers.data("ln_scale", {1024}, true); \
auto* ln_bias = layers.data("ln_bias", {1024}, true); \
auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \
auto* qkv_bias = layers.data("qkv_bias", {3, 16, 64}, true); \
auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \
auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \
auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \
auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \
auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \
auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true);
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "ln_scale", {1024});
AddVarToScope(param_scope, "ln_bias", {1024});
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
AddVarToScope(param_scope, "ffn_ln_bias", {1024});
AddVarToScope(param_scope, "qkv_w", {3, 16, 64, 1024});
AddVarToScope(param_scope, "out_linear_w", {1024, 1024});
AddVarToScope(param_scope, "ffn1_w", {1024, 4096});
AddVarToScope(param_scope, "ffn2_w", {4096, 1024});
AddVarToScope(param_scope, "qkv_bias", {3072});
AddVarToScope(param_scope, "out_linear_bias", {1024});
AddVarToScope(param_scope, "ffn1_bias", {4096});
AddVarToScope(param_scope, "ffn2_bias", {1024});
return param_scope;
}
TEST(FusedMultiTransformerXPUQuantPass, context_stage) {
DEF_INPUT_DATA
auto* cache_kv = layers.fill_constant_batch_size_like(
x,
static_cast<int>(proto::VarType::FP32),
0,
1,
{2, -1, 16, 1024, 64},
0);
layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias,
ln_scale,
ln_bias,
ffn_ln_scale,
ffn_ln_bias,
0.1,
1e-12);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
if (pass.get() == nullptr) {
LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed";
}
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer_xpu");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(
num_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d",
num_nodes_after));
}
TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) {
DEF_INPUT_DATA
auto* cache_kv = layers.fill_constant_batch_size_like(
x,
static_cast<int>(proto::VarType::FP32),
0,
1,
{2, -1, 16, 1024, 64},
0);
auto* time_step = layers.data("time_step", {1});
layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias,
ln_scale,
ln_bias,
ffn_ln_scale,
ffn_ln_bias,
0.1,
1e-12,
time_step);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
if (pass.get() == nullptr) {
LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed";
}
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer_xpu");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(
num_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d",
num_nodes_after));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fused_multi_transformer_xpu_quant_pass);
......@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
"delete_op_device_pass",
......
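Because the pass is appended to the default XpuPassStrategy above, it runs automatically whenever an inference config targets XPU; no manual pass registration is needed. A minimal usage sketch with the public C++ inference API (the model paths are placeholders):
#include <memory>
#include "paddle_inference_api.h"

int main() {
  // Placeholder model files; any exported inference model containing
  // fused_multi_transformer ops would exercise the new pass.
  paddle_infer::Config config("./model/inference.pdmodel",
                              "./model/inference.pdiparams");
  // Targeting XPU selects XpuPassStrategy, which now contains
  // fused_multi_transformer_xpu_quant_pass.
  config.EnableXpu();
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}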
......@@ -47,6 +47,16 @@
param : [x, axis, keepdim, reduce_all]
backward : frobenius_norm_grad
- op : fused_multi_transformer_xpu
args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id)
output : Tensor(out), Tensor[](cache_kv_out){out_linear_w.size()}
infer_meta :
func : FusedMultiTransformerXpuInferMeta
kernel :
func : fused_multi_transformer_xpu
data_type : x
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask
- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
output : Tensor
......
......@@ -331,6 +331,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"fused_multi_transformer_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"unfold",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"unfold_grad",
......
......@@ -114,4 +114,108 @@ void MultiEncoderXPUInferMeta(
}
}
void FusedMultiTransformerXpuInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& ln_scale,
const std::vector<const MetaTensor*>& ln_bias,
const std::vector<const MetaTensor*>& qkvw,
const std::vector<const MetaTensor*>& qkvw_max,
const std::vector<const MetaTensor*>& qkv_bias,
const std::vector<const MetaTensor*>& out_linear_w,
const std::vector<const MetaTensor*>& out_linear_wmax,
const std::vector<const MetaTensor*>& out_linear_bias,
const std::vector<const MetaTensor*>& ffn_ln_scale,
const std::vector<const MetaTensor*>& ffn_ln_bias,
const std::vector<const MetaTensor*>& ffn1_weight,
const std::vector<const MetaTensor*>& ffn1_weight_max,
const std::vector<const MetaTensor*>& ffn1_bias,
const std::vector<const MetaTensor*>& ffn2_weight,
const std::vector<const MetaTensor*>& ffn2_weight_max,
const std::vector<const MetaTensor*>& ffn2_bias,
const std::vector<const MetaTensor*>& cache_kv,
const std::vector<const MetaTensor*>& pre_caches,
const std::vector<const MetaTensor*>& rotary_pos_emb,
const std::vector<const MetaTensor*>& time_step,
const std::vector<const MetaTensor*>& seq_lengths,
const std::vector<const MetaTensor*>& src_mask,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
float dropout_rate,
bool is_test,
const std::string& dropout_implementation,
const std::string& act_method,
bool trans_qkvw,
int ring_id,
MetaTensor* out,
std::vector<MetaTensor*> cache_kv_out) {
auto x_dim = x.dims();
auto y_dim = qkvw[0]->dims();
PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
phi::errors::InvalidArgument("The dimensions of x must be 3"
"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]",
x_dim.size()));
PADDLE_ENFORCE_EQ(
y_dim.size(),
4,
phi::errors::InvalidArgument("The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(
x_dim[2],
trans_qkvw ? y_dim[3] : y_dim[0],
phi::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
"true) or y_dim[0](trans_qkvw is false)"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim,
y_dim));
if (cache_kv.size() > 0) {
const auto& c_dim = cache_kv[0]->dims();
PADDLE_ENFORCE_EQ(
c_dim.size(),
5,
phi::errors::InvalidArgument("The CacheKV must be 5 dims, but got %d",
c_dim.size()));
PADDLE_ENFORCE_EQ(c_dim[0],
2,
phi::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(c_dim[1],
x_dim[0],
phi::errors::InvalidArgument(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d",
x_dim[0],
c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2],
trans_qkvw ? y_dim[1] : y_dim[2],
phi::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
trans_qkvw ? y_dim[1] : y_dim[2],
c_dim[2])); // num_head
PADDLE_ENFORCE_EQ(c_dim[4],
trans_qkvw ? y_dim[2] : y_dim[3],
phi::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
trans_qkvw ? y_dim[2] : y_dim[3],
c_dim[4])); // head_size
}
out->set_dims(x_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
} // namespace phi
......@@ -66,4 +66,39 @@ void MultiEncoderXPUInferMeta(
MetaTensor* x_fp16,
MetaTensor* out_fp16);
void FusedMultiTransformerXpuInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& ln_scale,
const std::vector<const MetaTensor*>& ln_bias,
const std::vector<const MetaTensor*>& qkvw,
const std::vector<const MetaTensor*>& qkvw_max,
const std::vector<const MetaTensor*>& qkv_bias,
const std::vector<const MetaTensor*>& out_linear_w,
const std::vector<const MetaTensor*>& out_linear_wmax,
const std::vector<const MetaTensor*>& out_linear_bias,
const std::vector<const MetaTensor*>& ffn_ln_scale,
const std::vector<const MetaTensor*>& ffn_ln_bias,
const std::vector<const MetaTensor*>& ffn1_weight,
const std::vector<const MetaTensor*>& ffn1_weight_max,
const std::vector<const MetaTensor*>& ffn1_bias,
const std::vector<const MetaTensor*>& ffn2_weight,
const std::vector<const MetaTensor*>& ffn2_weight_max,
const std::vector<const MetaTensor*>& ffn2_bias,
const std::vector<const MetaTensor*>& cache_kv,
const std::vector<const MetaTensor*>& pre_caches,
const std::vector<const MetaTensor*>& rotary_pos_emb,
const std::vector<const MetaTensor*>& time_step,
const std::vector<const MetaTensor*>& seq_lengths,
const std::vector<const MetaTensor*>& src_mask,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
float dropout_rate,
bool is_test,
const std::string& dropout_implementation,
const std::string& act_method,
bool trans_qkvw,
int ring_id,
MetaTensor* out,
std::vector<MetaTensor*> cache_kv_out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/memcpy_kernel.h"
#ifdef PADDLE_WITH_XPU_XFT
#include "models/fused_multi_transformer_op.h"
namespace xft = baidu::xpu::xft;
#endif
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedMultiTransformerXpuKernel(
const Context& ctx,
const DenseTensor& xx,
const std::vector<const DenseTensor*>& ln_scale,
const std::vector<const DenseTensor*>& ln_bias,
const std::vector<const DenseTensor*>& qkvw,
const std::vector<const DenseTensor*>& qkvw_max,
const std::vector<const DenseTensor*>& qkv_bias,
const std::vector<const DenseTensor*>& out_linear_w,
const std::vector<const DenseTensor*>& out_linear_wmax,
const std::vector<const DenseTensor*>& out_linear_bias,
const std::vector<const DenseTensor*>& ffn_ln_scale,
const std::vector<const DenseTensor*>& ffn_ln_bias,
const std::vector<const DenseTensor*>& ffn1_weight,
const std::vector<const DenseTensor*>& ffn1_weight_max,
const std::vector<const DenseTensor*>& ffn1_bias,
const std::vector<const DenseTensor*>& ffn2_weight,
const std::vector<const DenseTensor*>& ffn2_weight_max,
const std::vector<const DenseTensor*>& ffn2_bias,
const paddle::optional<std::vector<const DenseTensor*>>& cache_kv,
const paddle::optional<std::vector<const DenseTensor*>>& pre_caches,
const paddle::optional<DenseTensor>& rotary_pos_emb,
const paddle::optional<DenseTensor>& time_step,
const paddle::optional<DenseTensor>& seq_lengths,
const paddle::optional<DenseTensor>& src_mask,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
float dropout_rate,
bool is_test,
const std::string& dropout_implementation,
const std::string& act_method,
bool trans_qkvw,
int ring_id,
DenseTensor* out,
std::vector<DenseTensor*> cache_kv_out) {
#ifdef PADDLE_WITH_XPU_XFT
using XPUTypeT = typename XPUTypeTrait<T>::Type;
PADDLE_ENFORCE_EQ(pre_layer_norm,
true,
phi::errors::PreconditionNotMet(
"Only support pre_layer_norm = true at now."));
PADDLE_ENFORCE_EQ(
seq_lengths.get_ptr(),
nullptr,
phi::errors::PreconditionNotMet("seq_lengths not support at now."));
PADDLE_ENFORCE_EQ(
rotary_pos_emb.get_ptr(),
nullptr,
phi::errors::PreconditionNotMet("rotary_pos_emb not support at now."));
PADDLE_ENFORCE_EQ(
pre_caches.get_ptr(),
nullptr,
phi::errors::PreconditionNotMet("pre_caches not support at now."));
PADDLE_ENFORCE_NE(
src_mask.get_ptr(),
nullptr,
phi::errors::PreconditionNotMet("src_mask should not be nullptr."));
PADDLE_ENFORCE_EQ(trans_qkvw,
true,
phi::errors::PreconditionNotMet(
"Only support trans_qkvw == true at now."));
const auto x_dims = xx.dims();
int seq_len = x_dims[1];
const auto qkv_w_dims = qkvw[0]->dims();
int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
int time_step_value = -1;
if (time_step) {
PADDLE_ENFORCE_EQ(time_step.get_ptr()->place(),
phi::CPUPlace(),
phi::errors::PreconditionNotMet(
"The place of input(time_step) must be CPUPlace."));
// cache_seq_len
time_step_value = time_step.get_ptr()->data<int>()[0];
PADDLE_ENFORCE_GT(
time_step_value,
0,
phi::errors::PreconditionNotMet(
"The value of time_step must > 0, but now is %d", time_step_value));
PADDLE_ENFORCE_EQ(
seq_len,
1,
phi::errors::PreconditionNotMet(
"In decode stage, the seq_len of input must be 1, but now is %d",
seq_len));
}
XPUTypeT* x_data = reinterpret_cast<XPUTypeT*>(const_cast<T*>(xx.data<T>()));
XPUTypeT* src_mask_data = reinterpret_cast<XPUTypeT*>(
const_cast<T*>(src_mask.get_ptr()->data<T>()));
auto* out_data = reinterpret_cast<XPUTypeT*>(ctx.template Alloc<T>(out));
auto src_mask_dims = src_mask.get_ptr()->dims();
auto out_dims = out->dims();
auto xft_x = xft::xftTensor<XPUTypeT, 3>(
x_data, std::array<int64_t, 3>{x_dims[0], x_dims[1], x_dims[2]});
// TODO(mayang02): support mask.dtype = float16 in xft
xpu::ctx_guard RAII_GUARD(ctx.x_context());
float* src_mask_fp32_data =
RAII_GUARD.alloc<float>(src_mask.get_ptr()->numel());
int r = xpu::cast<XPUTypeT, float>(ctx.x_context(),
src_mask_data,
src_mask_fp32_data,
src_mask.get_ptr()->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::cast");
auto xft_src_mask =
xft::xftTensor<float, 4>(src_mask_fp32_data,
std::array<int64_t, 4>{src_mask_dims[0],
src_mask_dims[1],
src_mask_dims[2],
src_mask_dims[3]});
auto xft_out = xft::xftTensor<XPUTypeT, 3>(
out_data, std::array<int64_t, 3>{out_dims[0], out_dims[1], out_dims[2]});
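// TW matches the int16 weight dtype produced by
// fused_multi_transformer_xpu_quant_pass.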
typedef int16_t TW;
std::vector<xft::xftVec<float>> xft_ln_scale;
std::vector<xft::xftVec<float>> xft_ln_bias;
std::vector<xft::xftMat<TW>> xft_qkvw;
std::vector<xft::xftVec<float>> xft_qkv_bias;
std::vector<xft::xftMat<TW>> xft_out_linear_w;
std::vector<xft::xftVec<float>> xft_out_linear_bias;
std::vector<xft::xftVec<float>> xft_ffn_ln_scale;
std::vector<xft::xftVec<float>> xft_ffn_ln_bias;
std::vector<xft::xftMat<TW>> xft_ffn1_w;
std::vector<xft::xftVec<float>> xft_ffn1_bias;
std::vector<xft::xftMat<TW>> xft_ffn2_w;
std::vector<xft::xftVec<float>> xft_ffn2_bias;
std::vector<xft::xftTensor<XPUTypeT, 5>> xft_cache_kv;
std::vector<xft::xftTensor<XPUTypeT, 5>> xft_cache_kv_out;
int layers = qkvw.size();
for (int i = 0; i < layers; ++i) {
// step1. layer_norm
xft_ln_scale.emplace_back(const_cast<float*>(ln_scale[i]->data<float>()),
std::array<int64_t, 1>{ln_scale[i]->dims()[0]});
xft_ln_bias.emplace_back(const_cast<float*>(ln_bias[i]->data<float>()),
std::array<int64_t, 1>{ln_bias[i]->dims()[0]});
// step2. qkv
auto qkvw_dims = qkvw[i]->dims();
xft_qkvw.emplace_back(
const_cast<TW*>(qkvw[i]->data<TW>()),
const_cast<float*>(qkvw_max[i]->data<float>()),
std::array<int64_t, 2>{qkvw_dims[0] * qkvw_dims[1] * qkvw_dims[2],
qkvw_dims[3]});
auto qkvb_dims = qkv_bias[i]->dims();
xft_qkv_bias.emplace_back(
const_cast<float*>(qkv_bias[i]->data<float>()),
std::array<int64_t, 1>{qkvb_dims[0] * qkvb_dims[1] * qkvb_dims[2]});
// attn out
auto outw_dims = out_linear_w[i]->dims();
xft_out_linear_w.emplace_back(
const_cast<TW*>(out_linear_w[i]->data<TW>()),
const_cast<float*>(out_linear_wmax[i]->data<float>()),
std::array<int64_t, 2>{outw_dims[0], outw_dims[1]});
xft_out_linear_bias.emplace_back(
const_cast<float*>(out_linear_bias[i]->data<float>()),
std::array<int64_t, 1>{out_linear_bias[i]->dims()[0]});
// ffn ln
xft_ffn_ln_scale.emplace_back(
const_cast<float*>(ffn_ln_scale[i]->data<float>()),
std::array<int64_t, 1>{ffn_ln_scale[i]->dims()[0]});
xft_ffn_ln_bias.emplace_back(
const_cast<float*>(ffn_ln_bias[i]->data<float>()),
std::array<int64_t, 1>{ffn_ln_bias[i]->dims()[0]});
// ffn1
auto ffn1w_dims = ffn1_weight[i]->dims();
xft_ffn1_w.emplace_back(
const_cast<TW*>(ffn1_weight[i]->data<TW>()),
const_cast<float*>(ffn1_weight_max[i]->data<float>()),
std::array<int64_t, 2>{ffn1w_dims[0], ffn1w_dims[1]});
xft_ffn1_bias.emplace_back(const_cast<float*>(ffn1_bias[i]->data<float>()),
std::array<int64_t, 1>{ffn1_bias[i]->dims()[0]});
// ffn2
auto ffn2w_dims = ffn2_weight[i]->dims();
xft_ffn2_w.emplace_back(
const_cast<TW*>(ffn2_weight[i]->data<TW>()),
const_cast<float*>(ffn2_weight_max[i]->data<float>()),
std::array<int64_t, 2>{ffn2w_dims[0], ffn2w_dims[1]});
xft_ffn2_bias.emplace_back(const_cast<float*>(ffn2_bias[i]->data<float>()),
std::array<int64_t, 1>{ffn2_bias[i]->dims()[0]});
// cache kv in
if (time_step_value > 0) {
auto cachekv_dims = cache_kv.get_ptr()->at(i)->dims();
xft_cache_kv.emplace_back(reinterpret_cast<XPUTypeT*>(const_cast<T*>(
cache_kv.get_ptr()->at(i)->data<T>())),
std::array<int64_t, 5>{cachekv_dims[0],
cachekv_dims[1],
cachekv_dims[2],
cachekv_dims[3],
cachekv_dims[4]});
}
// cache kv out
auto cachekv_out_dims = cache_kv_out[i]->dims();
xft_cache_kv_out.emplace_back(
reinterpret_cast<XPUTypeT*>(ctx.template Alloc<T>(cache_kv_out[i])),
std::array<int64_t, 5>{cachekv_out_dims[0],
cachekv_out_dims[1],
cachekv_out_dims[2],
cachekv_out_dims[3],
cachekv_out_dims[4]});
}
xft::NlpParam param;
param.num_layer = layers;
param.n_head = num_head;
param.size_per_head = dim_head;
param.hidden_act = act_method;
param.is_fuse_qkv = true;
r = xft::fused_multi_transformer<XPUTypeT, TW, int16_t>(ctx.x_context(),
xft_x,
xft_cache_kv,
xft_src_mask,
xft_ln_scale,
xft_ln_bias,
xft_qkvw,
xft_qkv_bias,
xft_out_linear_w,
xft_out_linear_bias,
xft_ffn_ln_scale,
xft_ffn_ln_bias,
xft_ffn1_w,
xft_ffn1_bias,
xft_ffn2_w,
xft_ffn2_bias,
param,
time_step_value,
&xft_out,
xft_cache_kv_out);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xft::fused_multi_transformer");
#else
LOG(FATAL) << "fused_multi_transformer_xpu is not supported since it's not "
"compiled with XPU_XFT";
#endif
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_multi_transformer_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::FusedMultiTransformerXpuKernel,
float,
phi::dtype::float16) {
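// Input index 20 is time_step (see the op definition); the kernel reads its
// value on the host, so keep it on CPU.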
kernel->InputAt(20).SetBackend(phi::Backend::CPU);
}