From 58afe45cb6bbc8ec59a222126f95df9c8bd01d63 Mon Sep 17 00:00:00 2001
From: Cwndmiao
Date: Wed, 26 Aug 2020 13:26:18 +0800
Subject: [PATCH] [LITE][XPU] fuse q/k/v in multi_encoder (#4141)

* [LITE][XPU] fuse q/k/v in multi_encoder

* test=develop, test=xpu

* test=develop, test=xpu
---
 .../fusion/__xpu__multi_encoder_fuse_pass.cc | 139 ++++++++++++++++++
 .../xpu/__xpu__multi_encoder_compute.cc      |   1 +
 lite/operators/__xpu__multi_encoder_op.cc    |   1 +
 lite/operators/op_params.h                   |   1 +
 4 files changed, 142 insertions(+)

diff --git a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
index 8546af5414..c88e576659 100644
--- a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
+++ b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -519,14 +519,109 @@ class XPUMultiEncoderFuser {
     op_desc.SetAttr<std::string>("precision",
                                  (fc_int31_ids_.empty() ? "int16" : "int31"));
+    // check q/k/v fusion
+    bool enable_qkv_fusion = false;
+    if (!fc_int31_ids_.empty()) {
+      int head_num = first_encoder_op_info->GetAttr<int>("head_num");
+      int size_per_head = first_encoder_op_info->GetAttr<int>("size_per_head");
+      if (head_num * size_per_head <= 128) {
+        enable_qkv_fusion = true;
+      }
+    }
+    op_desc.SetAttr<bool>("enable_qkv_fusion", enable_qkv_fusion);
+
     auto* scope = multi_encoder_stmt->op()->scope();
     std::vector<float> fc_weight_max(arg_map["FCWeight"].size());
     auto& fc_weight_names = arg_map["FCWeight"];
     for (size_t i = 0; i < fc_weight_names.size(); ++i) {
+      if (enable_qkv_fusion && (i % 6 == 0)) {
+        // q/k/v FCWeight fusion
+        auto* weight_q = scope->FindMutableTensor(fc_weight_names[i]);
+        auto* weight_k = scope->FindMutableTensor(fc_weight_names[i + 1]);
+        auto* weight_v = scope->FindMutableTensor(fc_weight_names[i + 2]);
+        auto weight_q_dims = weight_q->dims();
+        auto weight_k_dims = weight_k->dims();
+        auto weight_v_dims = weight_v->dims();
+        int weight_q_len = weight_q->numel();
+        int weight_k_len = weight_k->numel();
+        int weight_v_len = weight_v->numel();
+        float* weight_q_on_host = weight_q->mutable_data<float>();
+        float* weight_k_on_host = weight_k->mutable_data<float>();
+        float* weight_v_on_host = weight_v->mutable_data<float>();
+        int qkv_len = weight_q_len + weight_k_len + weight_v_len;
+        int qkv_offset = 0;
+        CHECK_EQ(weight_q_dims[0], weight_k_dims[0]);
+        CHECK_EQ(weight_q_dims[0], weight_v_dims[0]);
+
+        // 1. transpose
+        std::unique_ptr<float[]> weight_q_trans(new float[weight_q_len]);
+        std::unique_ptr<float[]> weight_k_trans(new float[weight_k_len]);
+        std::unique_ptr<float[]> weight_v_trans(new float[weight_v_len]);
+        std::unique_ptr<float[]> weight_qkv_trans(new float[qkv_len]);
+        paddle::lite::xpu::math::Transpose(weight_q_on_host,
+                                           weight_q_trans.get(),
+                                           weight_q_dims[0],
+                                           weight_q_dims[1]);
+        paddle::lite::xpu::math::Transpose(weight_k_on_host,
+                                           weight_k_trans.get(),
+                                           weight_k_dims[0],
+                                           weight_k_dims[1]);
+        paddle::lite::xpu::math::Transpose(weight_v_on_host,
+                                           weight_v_trans.get(),
+                                           weight_v_dims[0],
+                                           weight_v_dims[1]);
+
+        // 2. concat
+        memcpy(weight_qkv_trans.get() + qkv_offset,
+               weight_q_trans.get(),
+               weight_q_len * sizeof(float));
+        qkv_offset += weight_q_len;
+        memcpy(weight_qkv_trans.get() + qkv_offset,
+               weight_k_trans.get(),
+               weight_k_len * sizeof(float));
+        qkv_offset += weight_k_len;
+        memcpy(weight_qkv_trans.get() + qkv_offset,
+               weight_v_trans.get(),
+               weight_v_len * sizeof(float));
+        qkv_offset += weight_v_len;
+        CHECK_EQ(qkv_offset, qkv_len);
+
+        weight_q->Resize(
+            {weight_q_dims[1] + weight_k_dims[1] + weight_v_dims[1],
+             weight_q_dims[0]});
+
+        // 3. int31 or int16
+        float max_f = paddle::lite::xpu::math::FindMaxAbs(
+            weight_qkv_trans.get(), qkv_len);
+        fc_weight_max[i] = max_f;
+        if (fc_int31_ids_.find(i % 6) != fc_int31_ids_.end()) {
+          VLOG(3) << "Use FC-int31 in QKV fused FC-" << i << ", " << i / 6
+                  << "-" << i % 6;
+          memcpy(weight_q->mutable_data<float>(),
+                 weight_qkv_trans.get(),
+                 qkv_len * sizeof(float));
+        } else {
+          std::unique_ptr<int16_t[]> weight_qkv_trans_int16(
+              new int16_t[qkv_len]);
+          paddle::lite::xpu::math::ConvertFP32ToInt16(
+              weight_qkv_trans.get(),
+              weight_qkv_trans_int16.get(),
+              max_f,
+              qkv_len);
+          memcpy(weight_q->mutable_data<int16_t>(),
+                 weight_qkv_trans_int16.get(),
+                 qkv_len * sizeof(int16_t));
+        }
+
+        continue;
+      }
+
+      // no q/k/v fusion
       auto* weight_t = scope->FindMutableTensor(fc_weight_names[i]);
       auto weight_dims = weight_t->dims();
       int weight_len = weight_t->numel();
       float* weight_on_host = weight_t->mutable_data<float>();
+
       float max_f =
           paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len);
       // i ranges from 0 to 6*encoder_num, so we need to do i%6 to get relative
@@ -559,6 +654,50 @@ class XPUMultiEncoderFuser {
       fc_weight_max[i] = max_f;
     }
 
+    auto& fc_bias_names = arg_map["FCBias"];
+    for (size_t i = 0; enable_qkv_fusion && i < fc_bias_names.size(); i += 6) {
+      // q/k/v FCBias fusion
+      VLOG(3) << "Copy bias in QKV fused FC-" << i << ", " << i / 6 << "-"
+              << i % 6;
+      auto* bias_q = scope->FindMutableTensor(fc_bias_names[i]);
+      auto* bias_k = scope->FindMutableTensor(fc_bias_names[i + 1]);
+      auto* bias_v = scope->FindMutableTensor(fc_bias_names[i + 2]);
+      auto bias_q_dims = bias_q->dims();
+      auto bias_k_dims = bias_k->dims();
+      auto bias_v_dims = bias_v->dims();
+      int bias_q_len = bias_q->numel();
+      int bias_k_len = bias_k->numel();
+      int bias_v_len = bias_v->numel();
+      float* bias_q_on_host = bias_q->mutable_data<float>();
+      float* bias_k_on_host = bias_k->mutable_data<float>();
+      float* bias_v_on_host = bias_v->mutable_data<float>();
+      int qkv_len = bias_q_len + bias_k_len + bias_v_len;
+      int qkv_offset = 0;
+      CHECK_EQ(bias_q_dims.size(), 1);
+      CHECK_EQ(bias_k_dims.size(), 1);
+      CHECK_EQ(bias_v_dims.size(), 1);
+
+      std::unique_ptr<float[]> bias_qkv(new float[qkv_len]);
+      memcpy(bias_qkv.get() + qkv_offset,
+             bias_q_on_host,
+             bias_q_len * sizeof(float));
+      qkv_offset += bias_q_len;
+      memcpy(bias_qkv.get() + qkv_offset,
+             bias_k_on_host,
+             bias_k_len * sizeof(float));
+      qkv_offset += bias_k_len;
+      memcpy(bias_qkv.get() + qkv_offset,
+             bias_v_on_host,
+             bias_v_len * sizeof(float));
+      qkv_offset += bias_v_len;
+      CHECK_EQ(qkv_offset, qkv_len);
+
+      bias_q->Resize({qkv_len});
+      memcpy(bias_q->mutable_data<float>(),
+             bias_qkv.get(),
+             qkv_len * sizeof(float));
+    }
+
     std::string max_name = "encoder_max";
     auto* max_filter_node = graph->NewArgumentNode(max_name);
     max_filter_node->arg()->is_weight = true;
diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
index 781a548241..5a2f0ab2f0 100644
--- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
+++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
@@ -48,6 +48,7 @@ void XPUMultiEncoderCompute::Run() {
   int seq_len = param.input->dims()[1];
   int r = -1;
   if (param.precision == "int31") {
+    ctx.GetRawContext()->qkv_fusion = param.enable_qkv_fusion;
     r = xdnn::bert_encoder_transformer_int31(
         ctx.GetRawContext(), /* context */
         batch_size,          /* batch_size */
diff --git a/lite/operators/__xpu__multi_encoder_op.cc b/lite/operators/__xpu__multi_encoder_op.cc
index 5a1d2cb82e..d2035019b4 100644
--- a/lite/operators/__xpu__multi_encoder_op.cc
+++ b/lite/operators/__xpu__multi_encoder_op.cc
@@ -69,6 +69,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
   param_.size_per_head = op_desc.GetAttr<int>("size_per_head");
   param_.act_type = op_desc.GetAttr<std::string>("act_type");
   param_.precision = op_desc.GetAttr<std::string>("precision");
+  param_.enable_qkv_fusion = op_desc.GetAttr<bool>("enable_qkv_fusion");
   return true;
 }
 
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 93986a1903..3e68bc1631 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1661,6 +1661,7 @@ struct XPUMultiEncoderParam : ParamBase {
   int size_per_head{};
   std::string act_type{};
   std::string precision{};
+  bool enable_qkv_fusion{false};
 };
 
 struct XPUEmbeddingWithEltwiseAddParam : ParamBase {
-- 
GitLab
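
The core of the patch is the weight repacking in the first hunk: each of the
three [k, n] projection weights is transposed to [n, k], and the results are
stacked row-wise into one [n_q + n_k + n_v, k] buffer held by the q tensor, so
q, k and v can come out of a single wider FC. A minimal host-side sketch of
that layout transform (TransposeHost and FuseQKVWeights are illustrative names,
not part of the patch; TransposeHost stands in for
paddle::lite::xpu::math::Transpose, with the same call shape):

// fuse_qkv_weights_sketch.cc -- standalone illustration, not from the patch.
#include <cassert>
#include <cstddef>
#include <vector>

// Transpose a row-major [rows, cols] matrix into [cols, rows].
void TransposeHost(const float* in, float* out, int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      out[c * rows + r] = in[r * cols + c];
    }
  }
}

// weights[j] is a row-major [k, n[j]] FC weight. The fused result is
// row-major [n[0] + n[1] + n[2], k]: each weight transposed, then the three
// transposed blocks laid end to end -- the same effect as the pass's three
// Transpose calls followed by three memcpys.
std::vector<float> FuseQKVWeights(const std::vector<const float*>& weights,
                                  int k, const std::vector<int>& n) {
  assert(weights.size() == 3 && n.size() == 3);
  std::vector<float> fused(static_cast<size_t>(n[0] + n[1] + n[2]) * k);
  size_t offset = 0;
  for (int j = 0; j < 3; ++j) {
    // Transposing directly into the fused buffer merges steps 1 and 2.
    TransposeHost(weights[j], fused.data() + offset, k, n[j]);
    offset += static_cast<size_t>(n[j]) * k;
  }
  return fused;
}

After the repack, the q weight tensor is resized to hold the fused matrix
while the k and v tensors are left in place; with the qkv_fusion flag set on
the raw context, the xdnn runtime can presumably issue one GEMM of output
width n_q + n_k + n_v instead of three separate ones.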
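When a layer index is not in fc_int31_ids_, the fused weight is stored as
int16 with a single per-tensor scale (recorded in fc_weight_max[i] and later
packed into the "encoder_max" tensor). A sketch of that symmetric max-abs
scheme, with local stand-ins for the paddle::lite::xpu::math helpers of the
same names; the library's exact rounding and saturation behavior is not shown
in the patch, so round-to-nearest is an assumption here:

// quantize_int16_sketch.cc -- standalone illustration, not from the patch.
#include <algorithm>
#include <cmath>
#include <cstdint>

// Stand-in for paddle::lite::xpu::math::FindMaxAbs.
float FindMaxAbsHost(const float* data, int len) {
  float max_f = 0.0f;
  for (int i = 0; i < len; ++i) {
    max_f = std::max(max_f, std::fabs(data[i]));
  }
  return max_f;
}

// Stand-in for paddle::lite::xpu::math::ConvertFP32ToInt16: symmetric
// quantization, w_int16 = round(w / max_f * 32767). The kernel keeps max_f
// around to rescale the int16 products back to float at run time.
void ConvertFP32ToInt16Host(const float* in, int16_t* out, float max_f,
                            int len) {
  const float scale = (max_f > 0.0f) ? 32767.0f / max_f : 0.0f;
  for (int i = 0; i < len; ++i) {
    out[i] = static_cast<int16_t>(std::lround(in[i] * scale));
  }
}

Note that quantizing q, k and v as one tensor means they share one scale;
per-projection scales disappear, which is the usual accuracy trade-off of
this kind of fusion.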
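The bias fusion in the second hunk is a plain concatenation: the fused FC's
output dimension is the q, k and v output dimensions laid end to end, so the
1-D biases are stacked in the same order and line up with the row blocks of
the fused, transposed weight. A sketch (FuseQKVBias is an illustrative name):

// fuse_qkv_bias_sketch.cc -- standalone illustration, not from the patch.
#include <vector>

// Concatenate the three 1-D biases into [q | k | v], matching the fused
// weight's output-channel order.
std::vector<float> FuseQKVBias(const std::vector<float>& q,
                               const std::vector<float>& k,
                               const std::vector<float>& v) {
  std::vector<float> fused;
  fused.reserve(q.size() + k.size() + v.size());
  fused.insert(fused.end(), q.begin(), q.end());
  fused.insert(fused.end(), k.begin(), k.end());
  fused.insert(fused.end(), v.begin(), v.end());
  return fused;
}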