Unverified commit 1c97aa69, authored by zhupengyang, committed by GitHub

xpu quant weight only (#53306)

Parent: 2d17df97
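Summary of the change: this commit adds optional weight-only quantization for the XPU fused multi-transformer path. A new AnalysisConfig::SetXpuConfig(quant_post_dynamic_weight_bits, quant_post_dynamic_op_types) API records the requested weight bit-width (-1, 8, or 16) and the op types to quantize; the inference Argument carries both into IRPassManager, which sets a quant_weight_bits attribute on fused_multi_transformer_xpu_quant_pass when "fused_multi_transformer" is listed. The pass then prepares QKVW/OutLinearW/FFN1Weight/FFN2Weight as int8 when quant_weight_bits == 8, and keeps the previous int16 path otherwise.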
@@ -280,6 +280,8 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
                                           with_time_step,
                                           with_seq_lengths,
                                           with_src_mask);
+  int quant_weight_bits =
+      Has("quant_weight_bits") ? Get<int>("quant_weight_bits") : -1;

   int found_subgraph_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -312,36 +314,39 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
     // quant weight nodes
     // w_nodes_vec: [QKVW, OutLinearW, FFN1Weight, FFN2Weight]
     std::vector<std::vector<Node*>> w_nodes_vec(4);
-    std::vector<std::vector<Node*>> w_int16_nodes_vec(4);
+    std::vector<std::vector<Node*>> w_intx_nodes_vec(4);
     std::vector<std::vector<Node*>> w_max_nodes_vec(4);
-    std::vector<std::vector<std::string>> w_int16_names_vec(4);
+    std::vector<std::vector<std::string>> w_intx_names_vec(4);
     std::vector<std::vector<std::string>> w_max_names_vec(4);
     auto quant_func = [&](const std::string& input_name,
                           std::vector<Node*>* w_nodes,
-                          std::vector<Node*>* w_int16_nodes,
+                          std::vector<Node*>* w_intx_nodes,
                           std::vector<Node*>* w_max_nodes,
-                          std::vector<std::string>* w_int16_names,
+                          std::vector<std::string>* w_intx_names,
                           std::vector<std::string>* w_max_names,
                           bool need_transpose) {
-      typedef int16_t TW;
       auto w_names = fused_mt->Op()->Input(input_name);
       for (auto w_name : w_names) {
         Node* w_node = FindNodeWithName(graph, w_name);
-        Node* w_int16 = nullptr;
+        Node* w_intx = nullptr;
         Node* w_max = nullptr;
         PADDLE_ENFORCE_NE(
             w_node,
             nullptr,
             platform::errors::Fatal("w node should not be nullptr"));
-        PrepareWeight<TW>(
-            graph, scope, block, w_node, &w_int16, &w_max, need_transpose);
+        if (quant_weight_bits == 8) {
+          PrepareWeight<int8_t>(
+              graph, scope, block, w_node, &w_intx, &w_max, need_transpose);
+        } else {
+          PrepareWeight<int16_t>(
+              graph, scope, block, w_node, &w_intx, &w_max, need_transpose);
+        }
         w_nodes->push_back(w_node);
-        w_int16_nodes->push_back(w_int16);
+        w_intx_nodes->push_back(w_intx);
         w_max_nodes->push_back(w_max);
       }
       for (size_t i = 0; i < w_names.size(); ++i) {
-        w_int16_names->push_back(w_int16_nodes->at(i)->Name());
+        w_intx_names->push_back(w_intx_nodes->at(i)->Name());
         w_max_names->push_back(w_max_nodes->at(i)->Name());
       }
       PADDLE_ENFORCE_EQ(
@@ -353,11 +358,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
               static_cast<int>(w_nodes->size())));
       PADDLE_ENFORCE_EQ(
           w_names.size(),
-          w_int16_nodes->size(),
+          w_intx_nodes->size(),
           platform::errors::Fatal(
-              "The size of w_names(%d) should be equal to w_int16_nodes(%d)",
+              "The size of w_names(%d) should be equal to w_intx_nodes(%d)",
               static_cast<int>(w_names.size()),
-              static_cast<int>(w_int16_nodes->size())));
+              static_cast<int>(w_intx_nodes->size())));
       PADDLE_ENFORCE_EQ(
           w_names.size(),
           w_max_nodes->size(),
@@ -367,11 +372,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
               static_cast<int>(w_max_nodes->size())));
       PADDLE_ENFORCE_EQ(
           w_names.size(),
-          w_int16_names->size(),
+          w_intx_names->size(),
           platform::errors::Fatal(
-              "The size of w_names(%d) should be equal to w_int16_names(%d)",
+              "The size of w_names(%d) should be equal to w_intx_names(%d)",
               static_cast<int>(w_names.size()),
-              static_cast<int>(w_int16_names->size())));
+              static_cast<int>(w_intx_names->size())));
       PADDLE_ENFORCE_EQ(
           w_names.size(),
           w_max_names->size(),
@@ -382,30 +387,30 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
     };
     quant_func("QKVW",
               &(w_nodes_vec[0]),
-              &(w_int16_nodes_vec[0]),
+              &(w_intx_nodes_vec[0]),
               &(w_max_nodes_vec[0]),
-              &(w_int16_names_vec[0]),
+              &(w_intx_names_vec[0]),
               &(w_max_names_vec[0]),
               false);
     quant_func("OutLinearW",
               &(w_nodes_vec[1]),
-              &(w_int16_nodes_vec[1]),
+              &(w_intx_nodes_vec[1]),
               &(w_max_nodes_vec[1]),
-              &(w_int16_names_vec[1]),
+              &(w_intx_names_vec[1]),
               &(w_max_names_vec[1]),
               true);
     quant_func("FFN1Weight",
               &(w_nodes_vec[2]),
-              &(w_int16_nodes_vec[2]),
+              &(w_intx_nodes_vec[2]),
               &(w_max_nodes_vec[2]),
-              &(w_int16_names_vec[2]),
+              &(w_intx_names_vec[2]),
               &(w_max_names_vec[2]),
               true);
     quant_func("FFN2Weight",
               &(w_nodes_vec[3]),
-              &(w_int16_nodes_vec[3]),
+              &(w_intx_nodes_vec[3]),
               &(w_max_nodes_vec[3]),
-              &(w_int16_names_vec[3]),
+              &(w_intx_names_vec[3]),
               &(w_max_names_vec[3]),
               true);
@@ -482,13 +487,13 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
                                    name_caches.at("CacheKVOut"));
     fused_mt_xpu_op_desc->SetOutput("out", name_caches.at("Out"));
-    fused_mt_xpu_op_desc->SetInput("qkvw", w_int16_names_vec[0]);
+    fused_mt_xpu_op_desc->SetInput("qkvw", w_intx_names_vec[0]);
     fused_mt_xpu_op_desc->SetInput("qkvw_max", w_max_names_vec[0]);
-    fused_mt_xpu_op_desc->SetInput("out_linear_w", w_int16_names_vec[1]);
+    fused_mt_xpu_op_desc->SetInput("out_linear_w", w_intx_names_vec[1]);
     fused_mt_xpu_op_desc->SetInput("out_linear_wmax", w_max_names_vec[1]);
-    fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_int16_names_vec[2]);
+    fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_intx_names_vec[2]);
     fused_mt_xpu_op_desc->SetInput("ffn1_weight_max", w_max_names_vec[2]);
-    fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_int16_names_vec[3]);
+    fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_intx_names_vec[3]);
     fused_mt_xpu_op_desc->SetInput("ffn2_weight_max", w_max_names_vec[3]);
     if (!fused_mt_xpu_op_desc->HasAttr("rotary_emb_dims")) {
       fused_mt_xpu_op_desc->SetAttr("rotary_emb_dims", 0);
@@ -501,7 +506,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
     }
     // link int16 format of QKVW/OutLinearW/FFN1Weight/FFN2Weight to
     // fused_mt_xpu
-    for (auto nodes : w_int16_nodes_vec) {
+    for (auto nodes : w_intx_nodes_vec) {
       for (auto node : nodes) {
         IR_NODE_LINK_TO(node, fused_mt);
       }
......
@@ -193,6 +193,13 @@ template void PrepareWeight<int16_t>(Graph* graph,
                                      Node** dst,
                                      Node** dst_max,
                                      bool transpose);
+template void PrepareWeight<int8_t>(Graph* graph,
+                                    Scope* scope,
+                                    BlockDesc* block,
+                                    Node* src,
+                                    Node** dst,
+                                    Node** dst_max,
+                                    bool transpose);

 void PrepareBias(
     Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) {
......
@@ -207,6 +207,16 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
   }
 }

+template <>
+void QuantFP32ToIntX<int8_t>(const float* src_ptr,
+                             int8_t* dst_ptr,
+                             float max_val,
+                             int numel) {
+  for (int i = 0; i < numel; i++) {
+    dst_ptr[i] = Fp32ToIntx<int8_t, 127>(src_ptr[i], max_val);
+  }
+}
+
 template <typename T>
 void PrepareWeight(phi::DenseTensor* weight,
                    phi::DenseTensor* weight_max,
@@ -253,6 +263,9 @@ void PrepareWeight(phi::DenseTensor* weight,
 template void PrepareWeight<int16_t>(phi::DenseTensor* weight,
                                      phi::DenseTensor* weight_max,
                                      bool transpose);
+template void PrepareWeight<int8_t>(phi::DenseTensor* weight,
+                                    phi::DenseTensor* weight_max,
+                                    bool transpose);

 }  // namespace ir
 }  // namespace framework
......
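For intuition about the new int8 specialization above: it mirrors the existing int16 path, with a quantization ceiling of 127 instead of 32767. The following is a minimal standalone sketch of what Fp32ToIntx<int8_t, 127> is assumed to compute (symmetric max-abs scaling with clamping and rounding); the actual Paddle helper may differ in detail.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical re-implementation of Fp32ToIntx<int8_t, 127>, for
    // illustration only: map a float in [-max_val, max_val] onto
    // [-127, 127], clamp, and round to the nearest integer.
    static int8_t Fp32ToInt8Sketch(float v, float max_val) {
      float scaled = v * 127.0f / max_val;                   // symmetric scale
      scaled = std::max(-127.0f, std::min(127.0f, scaled));  // clamp
      return static_cast<int8_t>(std::lround(scaled));       // round
    }

QuantFP32ToIntX<int8_t> applies this element-wise, with max_val presumably the per-tensor absolute maximum that PrepareWeight stores alongside the quantized weights.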
@@ -289,6 +289,12 @@ struct Argument {
   DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool);
   DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
   DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool);
+  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_weight_bits,
+                      XpuQuantPostDynamicWeightBits,
+                      int);
+  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_op_types,
+                      XpuQuantPostDynamicOpTypss,
+                      std::vector<std::string>);
   DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool);
......
@@ -308,6 +308,14 @@ void IRPassManager::CreatePasses(Argument *argument,
       }
       bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
       pass->Set("use_fc_padding", new bool(use_fc_padding));
+    } else if (pass_name == "fused_multi_transformer_xpu_quant_pass") {
+      auto op_types = argument->xpu_quant_post_dynamic_op_types();
+      if (std::count(op_types.begin(),
+                     op_types.end(),
+                     "fused_multi_transformer") > 0) {
+        pass->Set("quant_weight_bits",
+                  new int(argument->xpu_quant_post_dynamic_weight_bits()));
+      }
     }
     pre_pass = pass_name;
......
@@ -196,6 +196,14 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) {
   Update();
 }

+void AnalysisConfig::SetXpuConfig(
+    int quant_post_dynamic_weight_bits,
+    const std::vector<std::string> &quant_post_dynamic_op_types) {
+  xpu_quant_post_dynamic_weight_bits_ = quant_post_dynamic_weight_bits;
+  xpu_quant_post_dynamic_op_types_ = quant_post_dynamic_op_types;
+  Update();
+}
+
 void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
                                         int device_id,
                                         Precision precision_mode) {
@@ -489,6 +497,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(xpu_precision_);
   CP_MEMBER(xpu_adaptive_seqlen_);
   CP_MEMBER(xpu_enable_multi_stream_);
+  CP_MEMBER(xpu_quant_post_dynamic_weight_bits_);
+  CP_MEMBER(xpu_quant_post_dynamic_op_types_);

   // Lite OpenCL Related
   CP_MEMBER(use_opencl_);
@@ -1091,6 +1101,10 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << xpu_precision_;
   ss << xpu_adaptive_seqlen_;
   ss << xpu_enable_multi_stream_;
+  ss << xpu_quant_post_dynamic_weight_bits_;
+  for (auto op_type : xpu_quant_post_dynamic_op_types_) {
+    ss << op_type;
+  }

   ss << use_npu_;
   ss << npu_device_id_;
@@ -1331,6 +1345,13 @@ std::string AnalysisConfig::Summary() {
       os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
       os.InsertRow(
           {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
+      os.InsertRow({"xpu_quant_post_dynamic_weight_bits",
+                    std::to_string(xpu_quant_post_dynamic_weight_bits_)});
+      std::vector<std::string> op_types{"xpu_quant_post_dynamic_op_types"};
+      for (auto op_type : xpu_quant_post_dynamic_op_types_) {
+        op_types.push_back(op_type);
+      }
+      os.InsertRow(op_types);
     }
     os.InsetDivider();
......
@@ -1426,6 +1426,10 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_);
     argument_->SetXpuDeviceId(config_.xpu_device_id_);
     argument_->SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_);
+    argument_->SetXpuQuantPostDynamicWeightBits(
+        config_.xpu_quant_post_dynamic_weight_bits_);
+    argument_->SetXpuQuantPostDynamicOpTypss(
+        config_.xpu_quant_post_dynamic_op_types_);
 #endif

     auto *pass_builder = config_.pass_builder();
......
@@ -288,6 +288,18 @@ struct PD_INFER_DECL AnalysisConfig {
                  bool adaptive_seqlen = false,
                  bool enable_multi_stream = false);

+  ///
+  /// \brief configs of XPU
+  ///
+  /// \param quant_post_dynamic_weight_bits Weight bits used in dynamic
+  /// post-quantization. Optional values: -1, 8, 16. The default is -1, which
+  /// means the recommended way is used.
+  /// \param quant_post_dynamic_op_types Ops used in dynamic post-quantization.
+  ///
+  void SetXpuConfig(
+      int quant_post_dynamic_weight_bits = -1,
+      const std::vector<std::string>& quant_post_dynamic_op_types = {});
+
   ///
   /// \brief configs of IPU
   ///
@@ -1181,6 +1193,8 @@ struct PD_INFER_DECL AnalysisConfig {
   std::string xpu_precision_;
   bool xpu_adaptive_seqlen_;
   bool xpu_enable_multi_stream_;
+  int xpu_quant_post_dynamic_weight_bits_{-1};
+  std::vector<std::string> xpu_quant_post_dynamic_op_types_;

   // LITE OPENCL SETTINGS
   bool use_opencl_{false};
......
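As a usage sketch on the C++ side (assuming the standard paddle_infer entry points, which are not part of this diff):

    #include "paddle_inference_api.h"  // assumed include name

    paddle_infer::Config config("./model_dir");  // hypothetical model path
    config.EnableXpu();
    // Request 8-bit weight-only quantization for fused_multi_transformer ops.
    // Any other value of quant_post_dynamic_weight_bits (e.g. the default -1)
    // leaves the pass on its int16 path.
    config.SetXpuConfig(8, {"fused_multi_transformer"});
    auto predictor = paddle_infer::CreatePredictor(config);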
@@ -767,6 +767,11 @@ void BindAnalysisConfig(py::module *m) {
       .def("set_xpu_device_id",
            &AnalysisConfig::SetXpuDeviceId,
            py::arg("device_id") = 0)
+      .def(
+          "set_xpu_config",
+          &AnalysisConfig::SetXpuConfig,
+          py::arg("quant_post_dynamic_weight_bits") = -1,
+          py::arg("quant_post_dynamic_op_types") = std::vector<std::string>({}))
       .def("enable_custom_device",
            &AnalysisConfig::EnableCustomDevice,
            py::arg("device_type"),
......
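Through this binding, the same knob is reachable from Python as config.set_xpu_config(quant_post_dynamic_weight_bits=8, quant_post_dynamic_op_types=["fused_multi_transformer"]) on the inference config object; leaving both arguments at their defaults preserves the existing int16 behavior.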