Unverified  Commit 52e1742f authored by: M mayang002  Committed by: GitHub

[xpu] fused_multi_transformer_xpu pass&kernel support (#51571)

Parent c36e3fd2
......@@ -142,6 +142,8 @@ if(WITH_XPU_XFT)
message(STATUS "Compile with XPU XFT!")
add_definitions(-DPADDLE_WITH_XPU_XFT)
set(XPU_XFT_INC_DIR "${XPU_INC_DIR}/xft")
include_directories(${XPU_XFT_INC_DIR})
set(XPU_XFT_LIB "${XPU_LIB_DIR}/${XPU_XFT_LIB_NAME}")
endif()
......
......@@ -235,6 +235,8 @@ if(WITH_XPU)
pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_isolated_node_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -493,4 +495,8 @@ if(WITH_XPU)
test_delete_isolated_node_pass
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass)
cc_test(
test_fused_multi_transformer_xpu_quant_pass
SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
DEPS fused_multi_transformer_xpu_quant_pass)
endif()
......@@ -75,7 +75,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
1,
{2, -1, 16, 1024, 64},
0);
auto* out = layers.fused_multi_transformer(x,
auto outs = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
......@@ -93,7 +93,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
0.1,
1e-12);
x = out;
x = outs[0];
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
......@@ -126,7 +126,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
for (int i = 0; i < num_layers; ++i) {
auto* shape_out = layers.shape(src_mask);
auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4});
auto* out = layers.fused_multi_transformer(x,
auto outs = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
......@@ -145,7 +145,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
1e-12,
time_stamp);
x = out;
x = outs[0];
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto param_scope = CreateParamScope();
......
......@@ -151,6 +151,15 @@ class Node {
var_desc_->SetName(new_name);
}
void RenameOp(const std::string& new_name) {
PADDLE_ENFORCE_EQ(
type_ == Type::kOperation && op_desc_,
true,
platform::errors::InvalidArgument("Node must be type of variable."));
name_ = new_name;
op_desc_->SetType(new_name);
}
int DescOrder() const { return desc_order_; }
int GetVarNodeBlockId() const {
......
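For context, a minimal sketch (hypothetical, not part of this commit) of how the new Node::RenameOp helper could be used from an IR pass to retarget an op node to the XPU fused kernel; the function name and loop below are illustrative only:
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {
namespace ir {
// Illustrative helper: rename every fused_multi_transformer op node so it
// dispatches to the XPU fused kernel.
void RenameToXpuFusedOp(Graph* graph) {
  for (Node* node : graph->Nodes()) {
    if (node->IsOp() && node->Op() != nullptr &&
        node->Op()->Type() == "fused_multi_transformer") {
      // RenameOp updates both the node name and the underlying OpDesc type.
      node->RenameOp("fused_multi_transformer_xpu");
    }
  }
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle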
......@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
"delete_op_device_pass"};
......
......@@ -571,33 +571,35 @@ struct Layers {
return out;
}
VarDesc* fused_multi_transformer(VarDesc* x,
VarDesc* cache_kv,
VarDesc* src_mask,
VarDesc* qkv_w,
VarDesc* qkv_bias,
VarDesc* out_linear_w,
VarDesc* out_linear_bias,
VarDesc* ffn1_w,
VarDesc* ffn1_bias,
VarDesc* ffn2_w,
VarDesc* ffn2_bias,
VarDesc* ln_scale,
VarDesc* ln_bias,
VarDesc* ffn_ln_scale,
VarDesc* ffn_ln_bias,
float epsilon,
float dropout_rate,
VarDesc* time_stamp = nullptr,
VarDesc* qkv_out_scale = nullptr,
VarDesc* out_linear_out_scale = nullptr,
VarDesc* ffn1_out_scale = nullptr,
VarDesc* ffn2_out_scale = nullptr,
std::vector<float> qkv_in_scale = {},
std::vector<float> out_linear_in_scale = {},
std::vector<float> ffn1_in_scale = {},
std::vector<float> ffn2_in_scale = {}) {
std::vector<VarDesc*> fused_multi_transformer(
VarDesc* x,
VarDesc* cache_kv,
VarDesc* src_mask,
VarDesc* qkv_w,
VarDesc* qkv_bias,
VarDesc* out_linear_w,
VarDesc* out_linear_bias,
VarDesc* ffn1_w,
VarDesc* ffn1_bias,
VarDesc* ffn2_w,
VarDesc* ffn2_bias,
VarDesc* ln_scale,
VarDesc* ln_bias,
VarDesc* ffn_ln_scale,
VarDesc* ffn_ln_bias,
float epsilon,
float dropout_rate,
VarDesc* time_stamp = nullptr,
VarDesc* qkv_out_scale = nullptr,
VarDesc* out_linear_out_scale = nullptr,
VarDesc* ffn1_out_scale = nullptr,
VarDesc* ffn2_out_scale = nullptr,
std::vector<float> qkv_in_scale = {},
std::vector<float> out_linear_in_scale = {},
std::vector<float> ffn1_in_scale = {},
std::vector<float> ffn2_in_scale = {}) {
VarDesc* out = lod_tensor(unique_name());
VarDesc* cache_kv_out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
std::string op_type = qkv_out_scale ? "fused_multi_transformer_int8"
: "fused_multi_transformer";
......@@ -623,6 +625,7 @@ struct Layers {
op->SetAttr("dropout_rate", dropout_rate);
op->SetAttr("epsilon", epsilon);
op->SetOutput("Out", {out->Name()});
op->SetOutput("CacheKVOut", {cache_kv_out->Name()});
if (time_stamp) {
op->SetInput("TimeStep", {time_stamp->Name()});
......@@ -638,7 +641,8 @@ struct Layers {
op->SetAttr("ffn1_in_scale", ffn1_in_scale);
op->SetAttr("ffn2_in_scale", ffn2_in_scale);
}
return out;
std::vector<VarDesc*> outs = {out, cache_kv_out};
return outs;
}
VarDesc* dequantize_linear(VarDesc* x,
......
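Note that Layers::fused_multi_transformer now returns both outputs instead of a single VarDesc*. A minimal usage sketch, mirroring the updated tests above (input variables as defined there; the local names are assumptions):
auto outs = layers.fused_multi_transformer(x, cache_kv, src_mask, qkv_w,
    qkv_bias, out_linear_w, out_linear_bias, ffn1_w, ffn1_bias, ffn2_w,
    ffn2_bias, ln_scale, ln_bias, ffn_ln_scale, ffn_ln_bias, 0.1, 1e-12);
VarDesc* hidden_out = outs[0];    // "Out"
VarDesc* cache_kv_out = outs[1];  // "CacheKVOut"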
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#define DEF_INPUT_DATA \
Layers layers; \
auto* x = layers.data("x", {1, 128, 1024}); \
auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \
auto* ln_scale = layers.data("ln_scale", {1024}, true); \
auto* ln_bias = layers.data("ln_bias", {1024}, true); \
auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \
auto* qkv_bias = layers.data("qkv_bias", {3, 16, 64}, true); \
auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \
auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \
auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \
auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \
auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \
auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true);
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "ln_scale", {1024});
AddVarToScope(param_scope, "ln_bias", {1024});
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
AddVarToScope(param_scope, "ffn_ln_bias", {1024});
AddVarToScope(param_scope, "qkv_w", {3, 16, 64, 1024});
AddVarToScope(param_scope, "out_linear_w", {1024, 1024});
AddVarToScope(param_scope, "ffn1_w", {1024, 4096});
AddVarToScope(param_scope, "ffn2_w", {4096, 1024});
AddVarToScope(param_scope, "qkv_bias", {3072});
AddVarToScope(param_scope, "out_linear_bias", {1024});
AddVarToScope(param_scope, "ffn1_bias", {4096});
AddVarToScope(param_scope, "ffn2_bias", {1024});
return param_scope;
}
TEST(FusedMultiTransformerXPUQuantPass, context_stage) {
DEF_INPUT_DATA
auto* cache_kv = layers.fill_constant_batch_size_like(
x,
static_cast<int>(proto::VarType::FP32),
0,
1,
{2, -1, 16, 1024, 64},
0);
layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias,
ln_scale,
ln_bias,
ffn_ln_scale,
ffn_ln_bias,
0.1,
1e-12);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
if (pass.get() == nullptr) {
LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed";
}
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer_xpu");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(
num_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d",
num_nodes_after));
}
TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) {
DEF_INPUT_DATA
auto* cache_kv = layers.fill_constant_batch_size_like(
x,
static_cast<int>(proto::VarType::FP32),
0,
1,
{2, -1, 16, 1024, 64},
0);
auto* time_step = layers.data("time_step", {1});
layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias,
ln_scale,
ln_bias,
ffn_ln_scale,
ffn_ln_bias,
0.1,
1e-12,
time_step);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
if (pass.get() == nullptr) {
LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed";
}
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer_xpu");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(
num_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d",
num_nodes_after));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fused_multi_transformer_xpu_quant_pass);
......@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
"delete_op_device_pass",
......
......@@ -47,6 +47,16 @@
param : [x, axis, keepdim, reduce_all]
backward : frobenius_norm_grad
- op : fused_multi_transformer_xpu
args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id)
output : Tensor(out), Tensor[](cache_kv_out){out_linear_w.size()}
infer_meta :
func : FusedMultiTransformerXpuInferMeta
kernel :
func : fused_multi_transformer_xpu
data_type : x
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask
- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
output : Tensor
......
......@@ -331,6 +331,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"fused_multi_transformer_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"unfold",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"unfold_grad",
......
......@@ -114,4 +114,108 @@ void MultiEncoderXPUInferMeta(
}
}
void FusedMultiTransformerXpuInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& ln_scale,
const std::vector<const MetaTensor*>& ln_bias,
const std::vector<const MetaTensor*>& qkvw,
const std::vector<const MetaTensor*>& qkvw_max,
const std::vector<const MetaTensor*>& qkv_bias,
const std::vector<const MetaTensor*>& out_linear_w,
const std::vector<const MetaTensor*>& out_linear_wmax,
const std::vector<const MetaTensor*>& out_linear_bias,
const std::vector<const MetaTensor*>& ffn_ln_scale,
const std::vector<const MetaTensor*>& ffn_ln_bias,
const std::vector<const MetaTensor*>& ffn1_weight,
const std::vector<const MetaTensor*>& ffn1_weight_max,
const std::vector<const MetaTensor*>& ffn1_bias,
const std::vector<const MetaTensor*>& ffn2_weight,
const std::vector<const MetaTensor*>& ffn2_weight_max,
const std::vector<const MetaTensor*>& ffn2_bias,
const std::vector<const MetaTensor*>& cache_kv,
const std::vector<const MetaTensor*>& pre_caches,
const std::vector<const MetaTensor*>& rotary_pos_emb,
const std::vector<const MetaTensor*>& time_step,
const std::vector<const MetaTensor*>& seq_lengths,
const std::vector<const MetaTensor*>& src_mask,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
float dropout_rate,
bool is_test,
const std::string& dropout_implementation,
const std::string& act_method,
bool trans_qkvw,
int ring_id,
MetaTensor* out,
std::vector<MetaTensor*> cache_kv_out) {
auto x_dim = x.dims();
auto y_dim = qkvw[0]->dims();
PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
phi::errors::InvalidArgument("The dimensions of x must be 3"
"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]",
x_dim.size()));
PADDLE_ENFORCE_EQ(
y_dim.size(),
4,
phi::errors::InvalidArgument("The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(
x_dim[2],
trans_qkvw ? y_dim[3] : y_dim[0],
phi::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
"true) or y_dim[0](trans_qkvw is false)"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim,
y_dim));
if (cache_kv.size() > 0) {
const auto& c_dim = cache_kv[0]->dims();
PADDLE_ENFORCE_EQ(
c_dim.size(),
5,
phi::errors::InvalidArgument("The CacheKV must be 5 dims, but got %d",
c_dim.size()));
PADDLE_ENFORCE_EQ(c_dim[0],
2,
phi::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(c_dim[1],
x_dim[0],
phi::errors::InvalidArgument(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d",
x_dim[0],
c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2],
trans_qkvw ? y_dim[1] : y_dim[2],
phi::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
trans_qkvw ? y_dim[1] : y_dim[2],
c_dim[2])); // num_head
PADDLE_ENFORCE_EQ(c_dim[4],
trans_qkvw ? y_dim[2] : y_dim[3],
phi::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
trans_qkvw ? y_dim[2] : y_dim[3],
c_dim[4])); // head_size
}
out->set_dims(x_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
} // namespace phi
......@@ -66,4 +66,39 @@ void MultiEncoderXPUInferMeta(
MetaTensor* x_fp16,
MetaTensor* out_fp16);
void FusedMultiTransformerXpuInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& ln_scale,
const std::vector<const MetaTensor*>& ln_bias,
const std::vector<const MetaTensor*>& qkvw,
const std::vector<const MetaTensor*>& qkvw_max,
const std::vector<const MetaTensor*>& qkv_bias,
const std::vector<const MetaTensor*>& out_linear_w,
const std::vector<const MetaTensor*>& out_linear_wmax,
const std::vector<const MetaTensor*>& out_linear_bias,
const std::vector<const MetaTensor*>& ffn_ln_scale,
const std::vector<const MetaTensor*>& ffn_ln_bias,
const std::vector<const MetaTensor*>& ffn1_weight,
const std::vector<const MetaTensor*>& ffn1_weight_max,
const std::vector<const MetaTensor*>& ffn1_bias,
const std::vector<const MetaTensor*>& ffn2_weight,
const std::vector<const MetaTensor*>& ffn2_weight_max,
const std::vector<const MetaTensor*>& ffn2_bias,
const std::vector<const MetaTensor*>& cache_kv,
const std::vector<const MetaTensor*>& pre_caches,
const std::vector<const MetaTensor*>& rotary_pos_emb,
const std::vector<const MetaTensor*>& time_step,
const std::vector<const MetaTensor*>& seq_lengths,
const std::vector<const MetaTensor*>& src_mask,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
float dropout_rate,
bool is_test,
const std::string& dropout_implementation,
const std::string& act_method,
bool trans_qkvw,
int ring_id,
MetaTensor* out,
std::vector<MetaTensor*> cache_kv_out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/memcpy_kernel.h"
#ifdef PADDLE_WITH_XPU_XFT
#include "models/fused_multi_transformer_op.h"
namespace xft = baidu::xpu::xft;
#endif
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedMultiTransformerXpuKernel(
const Context& ctx,
const DenseTensor& xx,
const std::vector<const DenseTensor*>& ln_scale,
const std::vector<const DenseTensor*>& ln_bias,
const std::vector<const DenseTensor*>& qkvw,
const std::vector<const DenseTensor*>& qkvw_max,
const std::vector<const DenseTensor*>& qkv_bias,
const std::vector<const DenseTensor*>& out_linear_w,
const std::vector<const DenseTensor*>& out_linear_wmax,
const std::vector<const DenseTensor*>& out_linear_bias,
const std::vector<const DenseTensor*>& ffn_ln_scale,
const std::vector<const DenseTensor*>& ffn_ln_bias,
const std::vector<const DenseTensor*>& ffn1_weight,
const std::vector<const DenseTensor*>& ffn1_weight_max,
const std::vector<const DenseTensor*>& ffn1_bias,
const std::vector<const DenseTensor*>& ffn2_weight,
const std::vector<const DenseTensor*>& ffn2_weight_max,
const std::vector<const DenseTensor*>& ffn2_bias,
const paddle::optional<std::vector<const DenseTensor*>>& cache_kv,
const paddle::optional<std::vector<const DenseTensor*>>& pre_caches,
const paddle::optional<DenseTensor>& rotary_pos_emb,
const paddle::optional<DenseTensor>& time_step,
const paddle::optional<DenseTensor>& seq_lengths,
const paddle::optional<DenseTensor>& src_mask,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
float dropout_rate,
bool is_test,
const std::string& dropout_implementation,
const std::string& act_method,
bool trans_qkvw,
int ring_id,
DenseTensor* out,
std::vector<DenseTensor*> cache_kv_out) {
#ifdef PADDLE_WITH_XPU_XFT
using XPUTypeT = typename XPUTypeTrait<T>::Type;
PADDLE_ENFORCE_EQ(pre_layer_norm,
true,
phi::errors::PreconditionNotMet(
"Only support pre_layer_norm = true at now."));
PADDLE_ENFORCE_EQ(
seq_lengths.get_ptr(),
nullptr,
phi::errors::PreconditionNotMet("seq_lengths not support at now."));
PADDLE_ENFORCE_EQ(
rotary_pos_emb.get_ptr(),
nullptr,
phi::errors::PreconditionNotMet("rotary_pos_emb not support at now."));
PADDLE_ENFORCE_EQ(
pre_caches.get_ptr(),
nullptr,
phi::errors::PreconditionNotMet("pre_caches not support at now."));
PADDLE_ENFORCE_NE(
src_mask.get_ptr(),
nullptr,
phi::errors::PreconditionNotMet("src_mask should not be nullptr."));
PADDLE_ENFORCE_EQ(trans_qkvw,
true,
phi::errors::PreconditionNotMet(
"Only support trans_qkvw == true at now."));
const auto x_dims = xx.dims();
int seq_len = x_dims[1];
const auto qkv_w_dims = qkvw[0]->dims();
int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
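// time_step_value stays -1 in the context (encoder) stage; in the decode
// stage it holds the current cache sequence length read from the CPU
// time_step tensor below.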
int time_step_value = -1;
if (time_step) {
PADDLE_ENFORCE_EQ(time_step.get_ptr()->place(),
phi::CPUPlace(),
phi::errors::PreconditionNotMet(
"The place of input(time_step) must be CPUPlace."));
// cache_seq_len
time_step_value = time_step.get_ptr()->data<int>()[0];
PADDLE_ENFORCE_GT(
time_step_value,
0,
phi::errors::PreconditionNotMet(
"The value of time_step must > 0, but now is %d", time_step_value));
PADDLE_ENFORCE_EQ(
seq_len,
1,
phi::errors::PreconditionNotMet(
"In decode stage, the seq_len of input must be 1, but now is %d",
seq_len));
}
XPUTypeT* x_data = reinterpret_cast<XPUTypeT*>(const_cast<T*>(xx.data<T>()));
XPUTypeT* src_mask_data = reinterpret_cast<XPUTypeT*>(
const_cast<T*>(src_mask.get_ptr()->data<T>()));
auto* out_data = reinterpret_cast<XPUTypeT*>(ctx.template Alloc<T>(out));
auto src_mask_dims = src_mask.get_ptr()->dims();
auto out_dims = out->dims();
auto xft_x = xft::xftTensor<XPUTypeT, 3>(
x_data, std::array<int64_t, 3>{x_dims[0], x_dims[1], x_dims[2]});
// TODO(mayang02): xft support mask.dtype = float16
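// Cast the attention mask to fp32 on device for XFT; RAII_GUARD releases the
// scratch buffer when it goes out of scope.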
xpu::ctx_guard RAII_GUARD(ctx.x_context());
float* src_mask_fp32_data =
RAII_GUARD.alloc<float>(src_mask.get_ptr()->numel());
int r = xpu::cast<XPUTypeT, float>(ctx.x_context(),
src_mask_data,
src_mask_fp32_data,
src_mask.get_ptr()->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::cast");
auto xft_src_mask =
xft::xftTensor<float, 4>(src_mask_fp32_data,
std::array<int64_t, 4>{src_mask_dims[0],
src_mask_dims[1],
src_mask_dims[2],
src_mask_dims[3]});
auto xft_out = xft::xftTensor<XPUTypeT, 3>(
out_data, std::array<int64_t, 3>{out_dims[0], out_dims[1], out_dims[2]});
typedef int16_t TW;
std::vector<xft::xftVec<float>> xft_ln_scale;
std::vector<xft::xftVec<float>> xft_ln_bias;
std::vector<xft::xftMat<TW>> xft_qkvw;
std::vector<xft::xftVec<float>> xft_qkv_bias;
std::vector<xft::xftMat<TW>> xft_out_linear_w;
std::vector<xft::xftVec<float>> xft_out_linear_bias;
std::vector<xft::xftVec<float>> xft_ffn_ln_scale;
std::vector<xft::xftVec<float>> xft_ffn_ln_bias;
std::vector<xft::xftMat<TW>> xft_ffn1_w;
std::vector<xft::xftVec<float>> xft_ffn1_bias;
std::vector<xft::xftMat<TW>> xft_ffn2_w;
std::vector<xft::xftVec<float>> xft_ffn2_bias;
std::vector<xft::xftTensor<XPUTypeT, 5>> xft_cache_kv;
std::vector<xft::xftTensor<XPUTypeT, 5>> xft_cache_kv_out;
int layers = qkvw.size();
for (int i = 0; i < layers; ++i) {
// step1. layer_norm
xft_ln_scale.emplace_back(const_cast<float*>(ln_scale[i]->data<float>()),
std::array<int64_t, 1>{ln_scale[i]->dims()[0]});
xft_ln_bias.emplace_back(const_cast<float*>(ln_bias[i]->data<float>()),
std::array<int64_t, 1>{ln_bias[i]->dims()[0]});
// step2. qkv
auto qkvw_dims = qkvw[i]->dims();
xft_qkvw.emplace_back(
const_cast<TW*>(qkvw[i]->data<TW>()),
const_cast<float*>(qkvw_max[i]->data<float>()),
std::array<int64_t, 2>{qkvw_dims[0] * qkvw_dims[1] * qkvw_dims[2],
qkvw_dims[3]});
auto qkvb_dims = qkv_bias[i]->dims();
xft_qkv_bias.emplace_back(
const_cast<float*>(qkv_bias[i]->data<float>()),
std::array<int64_t, 1>{qkvb_dims[0] * qkvb_dims[1] * qkvb_dims[2]});
// attn out
auto outw_dims = out_linear_w[i]->dims();
xft_out_linear_w.emplace_back(
const_cast<TW*>(out_linear_w[i]->data<TW>()),
const_cast<float*>(out_linear_wmax[i]->data<float>()),
std::array<int64_t, 2>{outw_dims[0], outw_dims[1]});
xft_out_linear_bias.emplace_back(
const_cast<float*>(out_linear_bias[i]->data<float>()),
std::array<int64_t, 1>{out_linear_bias[i]->dims()[0]});
// ffn ln
xft_ffn_ln_scale.emplace_back(
const_cast<float*>(ffn_ln_scale[i]->data<float>()),
std::array<int64_t, 1>{ffn_ln_scale[i]->dims()[0]});
xft_ffn_ln_bias.emplace_back(
const_cast<float*>(ffn_ln_bias[i]->data<float>()),
std::array<int64_t, 1>{ffn_ln_bias[i]->dims()[0]});
// ffn1
auto ffn1w_dims = ffn1_weight[i]->dims();
xft_ffn1_w.emplace_back(
const_cast<TW*>(ffn1_weight[i]->data<TW>()),
const_cast<float*>(ffn1_weight_max[i]->data<float>()),
std::array<int64_t, 2>{ffn1w_dims[0], ffn1w_dims[1]});
xft_ffn1_bias.emplace_back(const_cast<float*>(ffn1_bias[i]->data<float>()),
std::array<int64_t, 1>{ffn1_bias[i]->dims()[0]});
// ffn2
auto ffn2w_dims = ffn2_weight[i]->dims();
xft_ffn2_w.emplace_back(
const_cast<TW*>(ffn2_weight[i]->data<TW>()),
const_cast<float*>(ffn2_weight_max[i]->data<float>()),
std::array<int64_t, 2>{ffn2w_dims[0], ffn2w_dims[1]});
xft_ffn2_bias.emplace_back(const_cast<float*>(ffn2_bias[i]->data<float>()),
std::array<int64_t, 1>{ffn2_bias[i]->dims()[0]});
// cache kv in
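// Existing caches are only consumed in the decode stage (time_step_value > 0);
// the context stage only writes fresh caches via cache_kv_out below.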
if (time_step_value > 0) {
auto cachekv_dims = cache_kv.get_ptr()->at(i)->dims();
xft_cache_kv.emplace_back(reinterpret_cast<XPUTypeT*>(const_cast<T*>(
cache_kv.get_ptr()->at(i)->data<T>())),
std::array<int64_t, 5>{cachekv_dims[0],
cachekv_dims[1],
cachekv_dims[2],
cachekv_dims[3],
cachekv_dims[4]});
}
// cache kv out
auto cachekv_out_dims = cache_kv_out[i]->dims();
xft_cache_kv_out.emplace_back(
reinterpret_cast<XPUTypeT*>(ctx.template Alloc<T>(cache_kv_out[i])),
std::array<int64_t, 5>{cachekv_out_dims[0],
cachekv_out_dims[1],
cachekv_out_dims[2],
cachekv_out_dims[3],
cachekv_out_dims[4]});
}
xft::NlpParam param;
param.num_layer = layers;
param.n_head = num_head;
param.size_per_head = dim_head;
param.hidden_act = act_method;
param.is_fuse_qkv = true;
r = xft::fused_multi_transformer<XPUTypeT, TW, int16_t>(ctx.x_context(),
xft_x,
xft_cache_kv,
xft_src_mask,
xft_ln_scale,
xft_ln_bias,
xft_qkvw,
xft_qkv_bias,
xft_out_linear_w,
xft_out_linear_bias,
xft_ffn_ln_scale,
xft_ffn_ln_bias,
xft_ffn1_w,
xft_ffn1_bias,
xft_ffn2_w,
xft_ffn2_bias,
param,
time_step_value,
&xft_out,
xft_cache_kv_out);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xft::fused_multi_transformer");
#else
LOG(FATAL) << "fused_multi_transformer_xpu is not supported since it's not "
"compiled with XPU_XFT";
#endif
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_multi_transformer_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::FusedMultiTransformerXpuKernel,
float,
phi::dtype::float16) {
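// Input 20 is time_step; the kernel reads its value on the host, so it must
// live on CPU.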
kernel->InputAt(20).SetBackend(phi::Backend::CPU);
}