Unverified commit 2039115c authored by shentanyue, committed by GitHub

[XPU] Fusion of gather and assign operators to fused_mt op for reducing memory usage (#53262)

Parent: d27f15ed
......@@ -242,7 +242,7 @@ if(WITH_XPU)
pass_library(one_beam_size_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_isolated_node_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
pass_library(fused_multi_transformer_xpu_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_cachekv_layout_trans_pass inference DIR
......@@ -519,9 +519,9 @@ if(WITH_XPU)
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass)
cc_test(
test_fused_multi_transformer_xpu_quant_pass
SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
DEPS fused_multi_transformer_xpu_quant_pass)
test_fused_multi_transformer_xpu_pass
SRCS xpu/fused_multi_transformer_xpu_pass_tester.cc
DEPS fused_multi_transformer_xpu_pass)
cc_test(
test_one_beam_size_fuse_pass
SRCS xpu/one_beam_size_fuse_pass_test.cc
......
......@@ -65,7 +65,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"fused_multi_transformer_cachekv_layout_trans_pass",
"one_beam_size_fuse_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fused_multi_transformer_xpu_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
};
......
......@@ -39,6 +39,32 @@ namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerAssignPattern : public PatternBase {
FusedMultiTransformerAssignPattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(assign);
// declare variable node's name
PATTERN_DECL_NODE(assign_out);
};
FusedMultiTransformerAssignPattern::FusedMultiTransformerAssignPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* assign =
pattern->NewNode(assign_repr())
->assert_is_op("assign")
->assert_more([&](Node* node) {
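// Match only an assign whose input var has exactly one producer,
// and that producer is a fused_multi_transformer op.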
auto pre_op_nodes = node->inputs[0]->inputs;
return pre_op_nodes.size() == 1 &&
pre_op_nodes[0]->Op()->Type() == "fused_multi_transformer";
});
auto* assign_out =
pattern->NewNode(assign_out_repr())->assert_is_op_output("assign", "Out");
assign->LinksTo({assign_out});
}
struct FusedMultiTransformerPattern : public PatternBase {
FusedMultiTransformerPattern(PDPattern* pattern,
const std::string& name_scope,
......@@ -47,7 +73,6 @@ struct FusedMultiTransformerPattern : public PatternBase {
bool with_time_step,
bool with_seq_lengths,
bool with_src_mask);
// declare operator node's name
PATTERN_DECL_NODE(fused_mt);
// declare variable node's name
......@@ -234,39 +259,101 @@ FusedMultiTransformerPattern::FusedMultiTransformerPattern(
} // namespace patterns
/*
1. transpose and quantify the weights of fused_multi_transformer op from fp32 to
1. Remove the gather and assign ops to reduce device memory consumption
2. Transpose and quantize the weights of fused_multi_transformer op from fp32 to
int16
*/
class FusedMultiTransformerXPUQuantPass : public FusePassBase {
class FusedMultiTransformerXPUPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplyImpl(ir::Graph* graph,
/*
Origin subgraph:
fused_multi_transformer
| | |
assign assign ...
| | |
gather gather ...
Fused subgraph:
fused_multi_transformer
*/
void RemoveAssignGather(ir::Graph* graph) const;
/*
Origin subgraph:
fused_multi_transformer
Fused subgraph:
fused_multi_transformer_xpu
*/
int FusedMultiTransformerXPUQuant(ir::Graph* graph,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
bool with_seq_lengths,
bool with_src_mask) const;
const std::string name_scope_{"fused_multi_transformer_xpu_quant_pass"};
const std::string name_scope_{"fused_multi_transformer_xpu_pass"};
};
void FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph) const {
void FusedMultiTransformerXPUPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
VLOG(3) << "in FusedMultiTransformerXPUQuantPass::ApplyImpl";
VLOG(3) << "in FusedMultiTransformerXPUPass::ApplyImpl";
int found_subgraph_count = 0;
RemoveAssignGather(graph);
for (bool with_time_step : {true, false}) {
found_subgraph_count +=
ApplyImpl(graph, false, false, with_time_step, false, true);
found_subgraph_count += FusedMultiTransformerXPUQuant(
graph, false, false, with_time_step, false, true);
}
AddStatis(found_subgraph_count);
}
int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
void FusedMultiTransformerXPUPass::RemoveAssignGather(ir::Graph* graph) const {
// detect assign + gather
GraphPatternDetector gpd;
patterns::FusedMultiTransformerAssignPattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(1) << "handle RemoveAssignGather";
GET_IR_NODE(assign);
GET_IR_NODE(assign_out);
// assign_out may not be directly linked to the gather node, so locate gather by its input name.
auto next_ops = FindOpNodeByInputName(graph, assign_out->Name());
if (next_ops.size() != 1 || next_ops[0]->Name() != "gather") return;
auto* gather = next_ops[0];
// "assign_out" is used in multi blocks. "assign_out" should be reserved.
auto* gather_index = gather->inputs[0];
auto* assign_in = assign->inputs[0];
auto* fused_multi_transformer = assign_in->inputs[0];
fused_multi_transformer->Op()->Rename(assign_in->Name(),
assign_out->Name());
fused_multi_transformer->Op()->SetInput("gather_index",
gather->Op()->Input("Index"));
fused_multi_transformer->Op()->SetAttr("gather_axis",
gather->Op()->GetAttr("axis"));
IR_NODE_LINK_TO(gather_index, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, assign_out);
std::unordered_set<const Node*> delete_nodes{assign, assign_in, gather};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
int FusedMultiTransformerXPUPass::FusedMultiTransformerXPUQuant(
ir::Graph* graph,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
......@@ -286,7 +373,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FusedMultiTransformerXPUQuantPass fuse";
VLOG(4) << "handle FusedMultiTransformerXPUQuant";
GET_IR_NODE(x);
GET_IR_NODE(ln_scale);
......@@ -459,6 +546,13 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
if (name_caches.count("CacheKV") > 0) {
fused_mt_xpu_op_desc->SetInput("cache_kv", name_caches.at("CacheKV"));
}
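// Wire through the gather input/attr produced by RemoveAssignGather;
// fall back to gather_axis = 0 when that rewrite did not fire.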
if (name_caches.count("gather_index") > 0) {
fused_mt_xpu_op_desc->SetInput("gather_index",
name_caches.at("gather_index"));
}
if (!fused_mt_xpu_op_desc->HasAttr("gather_axis")) {
fused_mt_xpu_op_desc->SetAttr("gather_axis", 0);
}
if (pre_caches) {
fused_mt_xpu_op_desc->SetInput("pre_caches", name_caches.at("PreCaches"));
}
......@@ -529,5 +623,5 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_multi_transformer_xpu_quant_pass,
paddle::framework::ir::FusedMultiTransformerXPUQuantPass);
REGISTER_PASS(fused_multi_transformer_xpu_pass,
paddle::framework::ir::FusedMultiTransformerXPUPass);
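For context, a minimal usage sketch of the renamed pass (mirroring the testers below, and assuming an ir::Graph built from a ProgramDesc named program): fetch it from the registry by its registered name and apply it to the graph.
  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
  auto pass = PassRegistry::Instance().Get("fused_multi_transformer_xpu_pass");
  graph.reset(pass->Apply(graph.release()));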
......@@ -64,7 +64,62 @@ Scope* CreateParamScope() {
return param_scope;
}
TEST(FusedMultiTransformerXPUQuantPass, context_stage) {
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
proto::VarType::Type data_type = proto::VarType::FP32) {
auto* var = block->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
TEST(RemoveAssignGather, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* x = Data(block, "fused_multi_transformer_x", {1, 1, 1536});
auto* cache_kv =
Data(block, "fused_multi_transformer_cache_kv", {2, 1, 24, 512, 64});
OpDesc* fused_multi_transformer_op = block->AppendOp();
fused_multi_transformer_op->SetType("fused_multi_transformer");
fused_multi_transformer_op->SetInput("X", {x->Name()});
fused_multi_transformer_op->SetInput("CacheKV", {cache_kv->Name()});
fused_multi_transformer_op->SetOutput("CacheKVOut", {cache_kv->Name()});
auto* assign_out = Data(block, "assign_out", cache_kv->GetShape());
OpDesc* assign_op = block->AppendOp();
assign_op->SetType("assign");
assign_op->SetInput("X", {cache_kv->Name()});
assign_op->SetOutput("Out", {assign_out->Name()});
OpDesc* gather_op = block->AppendOp();
auto gather_index = Data(block, "gather_index", {10});
gather_op->SetType("gather");
gather_op->SetInput("X", {assign_out->Name()});
gather_op->SetInput("Index", {gather_index->Name()});
gather_op->SetAttr("axis", {1});
gather_op->SetOutput("Out", {cache_kv->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto pass = PassRegistry::Instance().Get("fused_multi_transformer_xpu_pass");
pass->Apply(graph.get());
auto assign_num = GetNumOpNodes(graph, "assign");
auto gather_num = GetNumOpNodes(graph, "gather");
PADDLE_ENFORCE_EQ(assign_num,
0,
platform::errors::PreconditionNotMet(
"assign op should be removed from the graph."));
PADDLE_ENFORCE_EQ(gather_num,
0,
platform::errors::PreconditionNotMet(
"gather op should be removed from the graph."));
}
TEST(FusedMultiTransformerXPUPass, context_stage) {
DEF_INPUT_DATA
auto* cache_kv = layers.fill_constant_batch_size_like(
......@@ -95,10 +150,9 @@ TEST(FusedMultiTransformerXPUQuantPass, context_stage) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
auto pass = PassRegistry::Instance().Get("fused_multi_transformer_xpu_pass");
if (pass.get() == nullptr) {
LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed";
LOG(INFO) << "get fused_multi_transformer_xpu_pass failed";
}
graph.reset(pass->Apply(graph.release()));
......@@ -114,7 +168,7 @@ TEST(FusedMultiTransformerXPUQuantPass, context_stage) {
num_nodes_after));
}
TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) {
TEST(FusedMultiTransformerXPUPass, decoder_stage) {
DEF_INPUT_DATA
auto* cache_kv = layers.fill_constant_batch_size_like(
......@@ -146,10 +200,9 @@ TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
auto pass = PassRegistry::Instance().Get("fused_multi_transformer_xpu_pass");
if (pass.get() == nullptr) {
LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed";
LOG(INFO) << "get fused_multi_transformer_xpu_pass failed";
}
graph.reset(pass->Apply(graph.release()));
......@@ -169,4 +222,4 @@ TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) {
} // namespace framework
} // namespace paddle
USE_PASS(fused_multi_transformer_xpu_quant_pass);
USE_PASS(fused_multi_transformer_xpu_pass);
......@@ -259,23 +259,6 @@ bool OnlyOneBeamSearchAndOneBeamSize(ir::Graph* graph) {
beam_search_nodes[0]->Op()->GetAttrIfExists<int>("beam_size") == 1;
}
std::vector<Node*> FindOpNodeByInputName(Graph* graph,
const std::string& var_name) {
std::vector<Node*> ret;
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) continue;
auto inputs = node->Op()->Inputs();
for (auto input : inputs) {
auto in_names = input.second;
if (std::count(in_names.begin(), in_names.end(), var_name) > 0) {
ret.push_back(node);
break;
}
}
}
return ret;
}
void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const {
// detect assign + gather
GraphPatternDetector gpd;
......
......@@ -71,6 +71,23 @@ Node* FindNodeWithName(Graph* graph, std::string name) {
return nullptr;
}
std::vector<Node*> FindOpNodeByInputName(Graph* graph,
const std::string& var_name) {
std::vector<Node*> ret;
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) continue;
auto inputs = node->Op()->Inputs();
for (auto input : inputs) {
auto in_names = input.second;
if (std::count(in_names.begin(), in_names.end(), var_name) > 0) {
ret.push_back(node);
break;
}
}
}
return ret;
}
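A usage sketch of this helper, mirroring how RemoveAssignGather consumes it above (the variable name "assign_out" is taken from that pass):
  // Find every op node that reads the variable named "assign_out".
  auto next_ops = FindOpNodeByInputName(graph, "assign_out");
  if (next_ops.size() == 1 && next_ops[0]->Name() == "gather") {
    Node* gather = next_ops[0];  // sole consumer; safe to fold upstream
  }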
template <typename T>
std::string IntTypeToString() {
LOG(FATAL) << "Not support type.";
......
......@@ -51,6 +51,9 @@ int ConvertActivationType(std::string act_type);
Node* FindNodeWithName(Graph* graph, std::string name);
std::vector<Node*> FindOpNodeByInputName(Graph* graph,
const std::string& var_name);
template <typename T>
size_t HashTensor(const phi::DenseTensor& in);
......
......@@ -523,7 +523,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"one_beam_size_fuse_pass",
"delete_cast_op_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fused_multi_transformer_xpu_pass",
"fc_xpu_fuse_pass",
"conv2d_xpu_fuse_pass",
"link_xpu_op_max_pass",
......
......@@ -58,14 +58,14 @@
support_dygraph_mode : true
- op : fused_multi_transformer_xpu
args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id)
args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id, int gather_axis)
output : Tensor(out), Tensor[](cache_kv_out){out_linear_w.size()}
infer_meta :
func : FusedMultiTransformerXpuInferMeta
kernel :
func : fused_multi_transformer_xpu
data_type : x
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index
- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
......
......@@ -278,6 +278,7 @@ void FusedMultiTransformerXpuInferMeta(
const std::vector<const MetaTensor*>& time_step,
const std::vector<const MetaTensor*>& seq_lengths,
const std::vector<const MetaTensor*>& src_mask,
const std::vector<const MetaTensor*>& gather_index,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
......@@ -287,6 +288,7 @@ void FusedMultiTransformerXpuInferMeta(
const std::string& act_method,
bool trans_qkvw,
int ring_id,
int gather_axis,
MetaTensor* out,
std::vector<MetaTensor*> cache_kv_out) {
auto x_dim = x.dims();
......@@ -325,13 +327,6 @@ void FusedMultiTransformerXpuInferMeta(
phi::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(
c_dim[2],
x_dim[0],
phi::errors::InvalidArgument("The third dim of CacheKV must be equal "
"with batch size %d, but got %d",
x_dim[0],
c_dim[2])); // batch_size
PADDLE_ENFORCE_EQ(
c_dim[3],
trans_qkvw ? y_dim[1] : y_dim[2],
......
......@@ -108,6 +108,7 @@ void FusedMultiTransformerXpuInferMeta(
const std::vector<const MetaTensor*>& time_step,
const std::vector<const MetaTensor*>& seq_lengths,
const std::vector<const MetaTensor*>& src_mask,
const std::vector<const MetaTensor*>& gather_index,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
......@@ -117,6 +118,7 @@ void FusedMultiTransformerXpuInferMeta(
const std::string& act_method,
bool trans_qkvw,
int ring_id,
int gather_axis,
MetaTensor* out,
std::vector<MetaTensor*> cache_kv_out);
} // namespace phi
......@@ -17,6 +17,8 @@
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/kernels/memcpy_kernel.h"
#ifdef PADDLE_WITH_XPU_XFT
#include "models/fused_multi_transformer_op.h"
......@@ -52,6 +54,7 @@ void FusedMultiTransformerXpuKernel(
const paddle::optional<DenseTensor>& time_step,
const paddle::optional<DenseTensor>& seq_lengths,
const paddle::optional<DenseTensor>& src_mask,
const paddle::optional<DenseTensor>& gather_index,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
......@@ -61,6 +64,7 @@ void FusedMultiTransformerXpuKernel(
const std::string& act_method,
bool trans_qkvw,
int ring_id,
int gather_axis,
DenseTensor* out,
std::vector<DenseTensor*> cache_kv_out) {
#ifdef PADDLE_WITH_XPU_XFT
......@@ -160,6 +164,21 @@ void FusedMultiTransformerXpuKernel(
std::vector<xft::xftTensor<XPUTypeT, 5>> xft_cache_kv;
std::vector<xft::xftTensor<XPUTypeT, 5>> xft_cache_kv_out;
// Create a temporary Tensor to store the gather output of cache_kv
auto gather_index_t = gather_index.get_ptr();
auto cache_kv_dims = cache_kv.get_ptr()->at(0)->dims();
auto cache_kv_gather_dims = cache_kv_dims;
phi::DenseTensor cache_kv_gather_tensor;
if (gather_index_t) {
MetaTensor cache_kv_gather_meta(&cache_kv_gather_tensor);
phi::GatherInferMeta(*cache_kv.get_ptr()->at(0),
*gather_index_t,
Scalar(gather_axis),
&cache_kv_gather_meta);
cache_kv_gather_dims = cache_kv_gather_meta.dims();
ctx.template Alloc<T>(&cache_kv_gather_tensor);
}
int layers = qkvw.size();
for (int i = 0; i < layers; ++i) {
// step1. layer_norm
......@@ -211,27 +230,55 @@ void FusedMultiTransformerXpuKernel(
xft_ffn2_bias.emplace_back(const_cast<float*>(ffn2_bias[i]->data<float>()),
std::array<int64_t, 1>{ffn2_bias[i]->dims()[0]});
// cache kv in
if (time_step_value > 0) {
auto cachekv_dims = cache_kv.get_ptr()->at(i)->dims();
xft_cache_kv.emplace_back(reinterpret_cast<XPUTypeT*>(const_cast<T*>(
cache_kv.get_ptr()->at(i)->data<T>())),
std::array<int64_t, 5>{cachekv_dims[0],
cachekv_dims[1],
cachekv_dims[2],
cachekv_dims[3],
cachekv_dims[4]});
auto cache_kv_data = reinterpret_cast<XPUTypeT*>(
const_cast<T*>(cache_kv.get_ptr()->at(i)->data<T>()));
if (gather_index_t) {
const auto& index_type = gather_index_t->dtype();
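// xpu::gather is templated on the index dtype, so dispatch on it here;
// a 0-D index tensor is treated as a single-element index below.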
if (index_type == DataType::INT32) {
r = xpu::gather<XPUTypeT, int32_t>(
ctx.x_context(),
cache_kv_data,
gather_index_t->data<int32_t>(),
reinterpret_cast<XPUTypeT*>(cache_kv_gather_tensor.data<T>()),
phi::vectorize<int32_t>(cache_kv_dims),
gather_index_t->dims().size() == 0 ? 1 : gather_index_t->dims()[0],
gather_axis);
} else {
r = xpu::gather<XPUTypeT, int64_t>(
ctx.x_context(),
cache_kv_data,
gather_index_t->data<int64_t>(),
reinterpret_cast<XPUTypeT*>(cache_kv_gather_tensor.data<T>()),
phi::vectorize<int32_t>(cache_kv_dims),
gather_index_t->dims().size() == 0 ? 1 : gather_index_t->dims()[0],
gather_axis);
}
// cache kv out
auto cachekv_out_dims = cache_kv_out[i]->dims();
xft_cache_kv_out.emplace_back(
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::gather");
cache_kv_out[i]->ResizeAndAllocate(cache_kv_gather_dims);
r = xpu::copy<XPUTypeT>(
ctx.x_context(),
reinterpret_cast<XPUTypeT*>(cache_kv_gather_tensor.data<T>()),
reinterpret_cast<XPUTypeT*>(ctx.template Alloc<T>(cache_kv_out[i])),
std::array<int64_t, 5>{cachekv_out_dims[0],
cachekv_out_dims[1],
cachekv_out_dims[2],
cachekv_out_dims[3],
cachekv_out_dims[4]});
cache_kv_out[i]->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::copy");
}
cache_kv_data = reinterpret_cast<XPUTypeT*>(
const_cast<T*>(cache_kv.get_ptr()->at(i)->data<T>()));
xft_cache_kv.emplace_back(cache_kv_data,
std::array<int64_t, 5>{cache_kv_gather_dims[0],
cache_kv_gather_dims[1],
cache_kv_gather_dims[2],
cache_kv_gather_dims[3],
cache_kv_gather_dims[4]});
// cache_kv_out directly reuses cache_kv_data, so XFT updates the cache in place
xft_cache_kv_out.emplace_back(
cache_kv_data,
std::array<int64_t, 5>{cache_kv_gather_dims[0],
cache_kv_gather_dims[1],
cache_kv_gather_dims[2],
cache_kv_gather_dims[3],
cache_kv_gather_dims[4]});
}
xft::NlpParam param;
param.num_layer = layers;
param.n_head = num_head;
......