Unverified · Commit 58afe45c authored by Cwndmiao, committed by GitHub

[LITE][XPU] fuse q/k/v in multi_encoder (#4141)

* [LITE][XPU] fuse q/k/v in multi_encoder

* test=develop, test=xpu

* test=develop, test=xpu
Parent 087aac4e
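
The idea behind this pass: within each encoder layer, the q, k and v projections all consume the same input activation, so their weight matrices can be stacked along the output dimension and computed with a single GEMM instead of three. A toy standalone sketch of that equivalence (plain C++; MatVec and every name below are illustrative only, not part of the patch):

    #include <cstdio>
    #include <vector>

    // y = W * x with W stored row-major as [out, in] -- the layout the pass
    // produces after transposing each FC weight.
    static std::vector<float> MatVec(const std::vector<float>& w,
                                     const std::vector<float>& x,
                                     int out, int in) {
      std::vector<float> y(out, 0.f);
      for (int m = 0; m < out; ++m)
        for (int k = 0; k < in; ++k)
          y[m] += w[m * in + k] * x[k];
      return y;
    }

    int main() {
      const int in = 3, out = 2;  // toy dimensions
      std::vector<float> wq = {1, 2, 3, 4, 5, 6};
      std::vector<float> wk = {7, 8, 9, 1, 0, 1};
      std::vector<float> wv = {2, 2, 2, 3, 3, 3};
      std::vector<float> x = {1, -1, 2};

      // Three separate projections.
      auto q = MatVec(wq, x, out, in);
      auto k = MatVec(wk, x, out, in);
      auto v = MatVec(wv, x, out, in);

      // One fused projection: stack the [out, in] blocks row-wise.
      std::vector<float> wqkv;
      wqkv.insert(wqkv.end(), wq.begin(), wq.end());
      wqkv.insert(wqkv.end(), wk.begin(), wk.end());
      wqkv.insert(wqkv.end(), wv.begin(), wv.end());
      auto qkv = MatVec(wqkv, x, 3 * out, in);

      // qkv is exactly [q; k; v].
      for (int m = 0; m < out; ++m)
        std::printf("q %f == %f, k %f == %f, v %f == %f\n",
                    q[m], qkv[m], k[m], qkv[out + m], v[m], qkv[2 * out + m]);
      return 0;
    }
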
@@ -519,14 +519,109 @@ class XPUMultiEncoderFuser {
op_desc.SetAttr<std::string>("precision",
                             (fc_int31_ids_.empty() ? "int16" : "int31"));
// check q/k/v fusion
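// q/k/v weights can be merged into one FC only on the int31 path, and only
// while head_num * size_per_head stays within 128 -- presumably an XPU
// kernel limit.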
bool enable_qkv_fusion = false;
if (!fc_int31_ids_.empty()) {
int head_num = first_encoder_op_info->GetAttr<int>("head_num");
int size_per_head = first_encoder_op_info->GetAttr<int>("size_per_head");
if (head_num * size_per_head <= 128) {
enable_qkv_fusion = true;
}
}
op_desc.SetAttr<bool>("enable_qkv_fusion", enable_qkv_fusion);
auto* scope = multi_encoder_stmt->op()->scope();
std::vector<float> fc_weight_max(arg_map["FCWeight"].size());
auto& fc_weight_names = arg_map["FCWeight"];
for (size_t i = 0; i < fc_weight_names.size(); ++i) {
if (enable_qkv_fusion && (i % 6 == 0)) {
// q/k/v FCWeight fusion
auto* weight_q = scope->FindMutableTensor(fc_weight_names[i]);
auto* weight_k = scope->FindMutableTensor(fc_weight_names[i + 1]);
auto* weight_v = scope->FindMutableTensor(fc_weight_names[i + 2]);
auto weight_q_dims = weight_q->dims();
auto weight_k_dims = weight_k->dims();
auto weight_v_dims = weight_v->dims();
int weight_q_len = weight_q->numel();
int weight_k_len = weight_k->numel();
int weight_v_len = weight_v->numel();
float* weight_q_on_host = weight_q->mutable_data<float>();
float* weight_k_on_host = weight_k->mutable_data<float>();
float* weight_v_on_host = weight_v->mutable_data<float>();
int qkv_len = weight_q_len + weight_k_len + weight_v_len;
int qkv_offset = 0;
CHECK_EQ(weight_q_dims[0], weight_k_dims[0]);
CHECK_EQ(weight_q_dims[0], weight_v_dims[0]);
// 1. transpose
std::unique_ptr<float[]> weight_q_trans(new float[weight_q_len]);
std::unique_ptr<float[]> weight_k_trans(new float[weight_k_len]);
std::unique_ptr<float[]> weight_v_trans(new float[weight_v_len]);
std::unique_ptr<float[]> weight_qkv_trans(new float[qkv_len]);
paddle::lite::xpu::math::Transpose(weight_q_on_host,
weight_q_trans.get(),
weight_q_dims[0],
weight_q_dims[1]);
paddle::lite::xpu::math::Transpose(weight_k_on_host,
weight_k_trans.get(),
weight_k_dims[0],
weight_k_dims[1]);
paddle::lite::xpu::math::Transpose(weight_v_on_host,
weight_v_trans.get(),
weight_v_dims[0],
weight_v_dims[1]);
// 2. concat
memcpy(weight_qkv_trans.get() + qkv_offset,
weight_q_trans.get(),
weight_q_len * sizeof(float));
qkv_offset += weight_q_len;
memcpy(weight_qkv_trans.get() + qkv_offset,
weight_k_trans.get(),
weight_k_len * sizeof(float));
qkv_offset += weight_k_len;
memcpy(weight_qkv_trans.get() + qkv_offset,
weight_v_trans.get(),
weight_v_len * sizeof(float));
qkv_offset += weight_v_len;
CHECK_EQ(qkv_offset, qkv_len);
weight_q->Resize(
{weight_q_dims[1] + weight_k_dims[1] + weight_v_dims[1],
weight_q_dims[0]});
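// weight_q is reused to carry the fused weight: the three transposed
// [out, in] blocks stacked along dim 0, i.e. [out_q + out_k + out_v, in].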
// 3. int31 or int16
float max_f = paddle::lite::xpu::math::FindMaxAbs(
weight_qkv_trans.get(), qkv_len);
fc_weight_max[i] = max_f;
if (fc_int31_ids_.find(i % 6) != fc_int31_ids_.end()) {
VLOG(3) << "Use FC-int31 in QKV fused FC-" << i << ", " << i / 6
<< "-" << i % 6;
memcpy(weight_q->mutable_data<float>(),
weight_qkv_trans.get(),
qkv_len * sizeof(float));
} else {
std::unique_ptr<int16_t[]> weight_qkv_trans_int16(
new int16_t[qkv_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
weight_qkv_trans.get(),
weight_qkv_trans_int16.get(),
max_f,
qkv_len);
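// pack the int16 payload into weight_q's float buffer; only
// qkv_len * sizeof(int16_t) bytes of it are meaningful.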
memcpy(weight_q->mutable_data<float>(),
weight_qkv_trans_int16.get(),
qkv_len * sizeof(int16_t));
}
continue;
}
// no q/k/v fusion
auto* weight_t = scope->FindMutableTensor(fc_weight_names[i]);
auto weight_dims = weight_t->dims();
int weight_len = weight_t->numel();
float* weight_on_host = weight_t->mutable_data<float>();
float max_f =
    paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len);
// i ranges from 0 to 6*encoder_num, so we need to do i%6 to get relative
@@ -559,6 +654,50 @@ class XPUMultiEncoderFuser {
fc_weight_max[i] = max_f;
}
auto& fc_bias_names = arg_map["FCBias"];
for (size_t i = 0; enable_qkv_fusion && i < fc_bias_names.size(); i += 6) {
// q/k/v FCBias fusion
VLOG(3) << "Copy bias in QKV fused FC-" << i << ", " << i / 6 << "-"
<< i % 6;
auto* bias_q = scope->FindMutableTensor(fc_bias_names[i]);
auto* bias_k = scope->FindMutableTensor(fc_bias_names[i + 1]);
auto* bias_v = scope->FindMutableTensor(fc_bias_names[i + 2]);
auto bias_q_dims = bias_q->dims();
auto bias_k_dims = bias_k->dims();
auto bias_v_dims = bias_v->dims();
int bias_q_len = bias_q->numel();
int bias_k_len = bias_k->numel();
int bias_v_len = bias_v->numel();
float* bias_q_on_host = bias_q->mutable_data<float>();
float* bias_k_on_host = bias_k->mutable_data<float>();
float* bias_v_on_host = bias_v->mutable_data<float>();
int qkv_len = bias_q_len + bias_k_len + bias_v_len;
int qkv_offset = 0;
CHECK_EQ(bias_q_dims.size(), 1);
CHECK_EQ(bias_k_dims.size(), 1);
CHECK_EQ(bias_v_dims.size(), 1);
std::unique_ptr<float[]> bias_qkv(new float[qkv_len]);
memcpy(bias_qkv.get() + qkv_offset,
bias_q_on_host,
bias_q_len * sizeof(float));
qkv_offset += bias_q_len;
memcpy(bias_qkv.get() + qkv_offset,
bias_k_on_host,
bias_k_len * sizeof(float));
qkv_offset += bias_k_len;
memcpy(bias_qkv.get() + qkv_offset,
bias_v_on_host,
bias_v_len * sizeof(float));
qkv_offset += bias_v_len;
CHECK_EQ(qkv_offset, qkv_len);
bias_q->Resize({qkv_len});
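// bias_q is reused to carry the concatenated q/k/v bias.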
memcpy(bias_q->mutable_data<float>(),
bias_qkv.get(),
qkv_len * sizeof(float));
}
std::string max_name = "encoder_max";
auto* max_filter_node = graph->NewArgumentNode(max_name);
max_filter_node->arg()->is_weight = true;
......
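
In the fuser above, FindMaxAbs supplies the scaling statistic and ConvertFP32ToInt16 does the narrowing for the int16 branch; neither body is shown in this diff. A plausible reading, given that a single max-abs value is passed in, is symmetric max-abs quantization. A toy sketch under that assumption (not the actual Paddle-Lite helper):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Assumed scheme: scale = 32767 / max_abs, round to nearest, clamp.
    static std::vector<int16_t> QuantizeToInt16(const std::vector<float>& src,
                                                float max_abs) {
      std::vector<int16_t> dst(src.size());
      const float scale = 32767.f / max_abs;
      for (size_t i = 0; i < src.size(); ++i) {
        float q = std::round(src[i] * scale);
        q = std::max(-32767.f, std::min(32767.f, q));
        dst[i] = static_cast<int16_t>(q);
      }
      return dst;
    }

    int main() {
      std::vector<float> w = {0.5f, -1.25f, 2.0f};
      float max_abs = 2.0f;  // what FindMaxAbs would return for w
      for (int16_t v : QuantizeToInt16(w, max_abs)) std::printf("%d\n", v);
      return 0;  // prints 8192, -20479, 32767
    }
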
@@ -48,6 +48,7 @@ void XPUMultiEncoderCompute::Run() {
int seq_len = param.input->dims()[1];
int r = -1;
if (param.precision == "int31") {
ctx.GetRawContext()->qkv_fusion = param.enable_qkv_fusion;
r = xdnn::bert_encoder_transformer_int31(
    ctx.GetRawContext(), /* context */
    batch_size, /* batch_size */
......
@@ -69,6 +69,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
param_.size_per_head = op_desc.GetAttr<int>("size_per_head");
param_.act_type = op_desc.GetAttr<std::string>("act_type");
param_.precision = op_desc.GetAttr<std::string>("precision");
param_.enable_qkv_fusion = op_desc.GetAttr<bool>("enable_qkv_fusion");
return true;
}
......
@@ -1661,6 +1661,7 @@ struct XPUMultiEncoderParam : ParamBase {
int size_per_head{};
std::string act_type{};
std::string precision{};
bool enable_qkv_fusion{false};
};
struct XPUEmbeddingWithEltwiseAddParam : ParamBase {
......