Unverified commit 1c97aa69, authored by zhupengyang, committed by GitHub

xpu quant weight only (#53306)

Parent 2d17df97
......@@ -280,6 +280,8 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
with_time_step,
with_seq_lengths,
with_src_mask);
int quant_weight_bits =
Has("quant_weight_bits") ? Get<int>("quant_weight_bits") : -1;
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......@@ -312,36 +314,39 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
// quant weight nodes
// w_nodes_vec: [QKVW, OutLinearW, FFN1Weight, FFN2Weight]
std::vector<std::vector<Node*>> w_nodes_vec(4);
std::vector<std::vector<Node*>> w_int16_nodes_vec(4);
std::vector<std::vector<Node*>> w_intx_nodes_vec(4);
std::vector<std::vector<Node*>> w_max_nodes_vec(4);
std::vector<std::vector<std::string>> w_int16_names_vec(4);
std::vector<std::vector<std::string>> w_intx_names_vec(4);
std::vector<std::vector<std::string>> w_max_names_vec(4);
auto quant_func = [&](const std::string& input_name,
std::vector<Node*>* w_nodes,
std::vector<Node*>* w_int16_nodes,
std::vector<Node*>* w_intx_nodes,
std::vector<Node*>* w_max_nodes,
std::vector<std::string>* w_int16_names,
std::vector<std::string>* w_intx_names,
std::vector<std::string>* w_max_names,
bool need_transpose) {
typedef int16_t TW;
auto w_names = fused_mt->Op()->Input(input_name);
for (auto w_name : w_names) {
Node* w_node = FindNodeWithName(graph, w_name);
Node* w_int16 = nullptr;
Node* w_intx = nullptr;
Node* w_max = nullptr;
PADDLE_ENFORCE_NE(
w_node,
nullptr,
platform::errors::Fatal("w node should not be nullptr"));
PrepareWeight<TW>(
graph, scope, block, w_node, &w_int16, &w_max, need_transpose);
if (quant_weight_bits == 8) {
PrepareWeight<int8_t>(
graph, scope, block, w_node, &w_intx, &w_max, need_transpose);
} else {
PrepareWeight<int16_t>(
graph, scope, block, w_node, &w_intx, &w_max, need_transpose);
}
w_nodes->push_back(w_node);
w_int16_nodes->push_back(w_int16);
w_intx_nodes->push_back(w_intx);
w_max_nodes->push_back(w_max);
}
for (size_t i = 0; i < w_names.size(); ++i) {
w_int16_names->push_back(w_int16_nodes->at(i)->Name());
w_intx_names->push_back(w_intx_nodes->at(i)->Name());
w_max_names->push_back(w_max_nodes->at(i)->Name());
}
PADDLE_ENFORCE_EQ(
......@@ -353,11 +358,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
static_cast<int>(w_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_int16_nodes->size(),
w_intx_nodes->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_int16_nodes(%d)",
"The size of w_names(%d) should be equal to w_intx_nodes(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_int16_nodes->size())));
static_cast<int>(w_intx_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_max_nodes->size(),
......@@ -367,11 +372,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
static_cast<int>(w_max_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_int16_names->size(),
w_intx_names->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_int16_names(%d)",
"The size of w_names(%d) should be equal to w_intx_names(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_int16_names->size())));
static_cast<int>(w_intx_names->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_max_names->size(),
......@@ -382,30 +387,30 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
};
quant_func("QKVW",
&(w_nodes_vec[0]),
&(w_int16_nodes_vec[0]),
&(w_intx_nodes_vec[0]),
&(w_max_nodes_vec[0]),
&(w_int16_names_vec[0]),
&(w_intx_names_vec[0]),
&(w_max_names_vec[0]),
false);
quant_func("OutLinearW",
&(w_nodes_vec[1]),
&(w_int16_nodes_vec[1]),
&(w_intx_nodes_vec[1]),
&(w_max_nodes_vec[1]),
&(w_int16_names_vec[1]),
&(w_intx_names_vec[1]),
&(w_max_names_vec[1]),
true);
quant_func("FFN1Weight",
&(w_nodes_vec[2]),
&(w_int16_nodes_vec[2]),
&(w_intx_nodes_vec[2]),
&(w_max_nodes_vec[2]),
&(w_int16_names_vec[2]),
&(w_intx_names_vec[2]),
&(w_max_names_vec[2]),
true);
quant_func("FFN2Weight",
&(w_nodes_vec[3]),
&(w_int16_nodes_vec[3]),
&(w_intx_nodes_vec[3]),
&(w_max_nodes_vec[3]),
&(w_int16_names_vec[3]),
&(w_intx_names_vec[3]),
&(w_max_names_vec[3]),
true);
......@@ -482,13 +487,13 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
name_caches.at("CacheKVOut"));
fused_mt_xpu_op_desc->SetOutput("out", name_caches.at("Out"));
fused_mt_xpu_op_desc->SetInput("qkvw", w_int16_names_vec[0]);
fused_mt_xpu_op_desc->SetInput("qkvw", w_intx_names_vec[0]);
fused_mt_xpu_op_desc->SetInput("qkvw_max", w_max_names_vec[0]);
fused_mt_xpu_op_desc->SetInput("out_linear_w", w_int16_names_vec[1]);
fused_mt_xpu_op_desc->SetInput("out_linear_w", w_intx_names_vec[1]);
fused_mt_xpu_op_desc->SetInput("out_linear_wmax", w_max_names_vec[1]);
fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_int16_names_vec[2]);
fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_intx_names_vec[2]);
fused_mt_xpu_op_desc->SetInput("ffn1_weight_max", w_max_names_vec[2]);
fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_int16_names_vec[3]);
fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_intx_names_vec[3]);
fused_mt_xpu_op_desc->SetInput("ffn2_weight_max", w_max_names_vec[3]);
if (!fused_mt_xpu_op_desc->HasAttr("rotary_emb_dims")) {
fused_mt_xpu_op_desc->SetAttr("rotary_emb_dims", 0);
......@@ -501,7 +506,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
}
// link intx (int8/int16) format of QKVW/OutLinearW/FFN1Weight/FFN2Weight to
// fused_mt_xpu
for (auto nodes : w_int16_nodes_vec) {
for (auto nodes : w_intx_nodes_vec) {
for (auto node : nodes) {
IR_NODE_LINK_TO(node, fused_mt);
}
......
......@@ -193,6 +193,13 @@ template void PrepareWeight<int16_t>(Graph* graph,
Node** dst,
Node** dst_max,
bool transpose);
template void PrepareWeight<int8_t>(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* src,
Node** dst,
Node** dst_max,
bool transpose);
void PrepareBias(
Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) {
......
......@@ -207,6 +207,16 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
}
}
template <>
void QuantFP32ToIntX<int8_t>(const float* src_ptr,
int8_t* dst_ptr,
float max_val,
int numel) {
for (int i = 0; i < numel; i++) {
dst_ptr[i] = Fp32ToIntx<int8_t, 127>(src_ptr[i], max_val);
}
}
template <typename T>
void PrepareWeight(phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
......@@ -253,6 +263,9 @@ void PrepareWeight(phi::DenseTensor* weight,
template void PrepareWeight<int16_t>(phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
bool transpose);
template void PrepareWeight<int8_t>(phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
bool transpose);
} // namespace ir
} // namespace framework
......
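For reference, the new `QuantFP32ToIntX<int8_t>` specialization above mirrors the existing int16 path but maps weights onto the signed 8-bit range via `Fp32ToIntx<int8_t, 127>`. The standalone sketch below illustrates what symmetric per-tensor quantization of this kind typically computes; the exact rounding/clipping behaviour and the helper name `QuantToInt8` are assumptions for illustration, not taken from the patch.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of symmetric per-tensor int8 quantization (assumed behaviour of
// Fp32ToIntx<int8_t, 127>): scale each float by 127 / max_val, round to
// nearest, and clip to the signed 8-bit range.
int8_t QuantToInt8(float v, float max_val) {
  float scaled = std::round(v * 127.0f / max_val);
  scaled = std::max(-127.0f, std::min(127.0f, scaled));
  return static_cast<int8_t>(scaled);
}

int main() {
  std::vector<float> w = {0.02f, -0.5f, 0.37f, -0.91f, 1.0f};

  // Per-tensor max of |w|, analogous to what the weight_max tensor produced by
  // PrepareWeight would hold.
  float max_val = 0.0f;
  for (float v : w) max_val = std::max(max_val, std::fabs(v));

  for (float v : w) {
    printf("%+.2f -> %d\n", v, static_cast<int>(QuantToInt8(v, max_val)));
  }
  return 0;
}
```

Compared with the int16 path, the 8-bit representation halves the weight storage, which is the point of passing `quant_post_dynamic_weight_bits = 8`.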
......@@ -289,6 +289,12 @@ struct Argument {
DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool);
DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool);
DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_weight_bits,
XpuQuantPostDynamicWeightBits,
int);
DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_op_types,
XpuQuantPostDynamicOpTypes,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool);
......
......@@ -308,6 +308,14 @@ void IRPassManager::CreatePasses(Argument *argument,
}
bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
pass->Set("use_fc_padding", new bool(use_fc_padding));
} else if (pass_name == "fused_multi_transformer_xpu_quant_pass") {
auto op_types = argument->xpu_quant_post_dynamic_op_types();
if (std::count(op_types.begin(),
op_types.end(),
"fused_multi_transformer") > 0) {
pass->Set("quant_weight_bits",
new int(argument->xpu_quant_post_dynamic_weight_bits()));
}
}
pre_pass = pass_name;
......
......@@ -196,6 +196,14 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) {
Update();
}
void AnalysisConfig::SetXpuConfig(
int quant_post_dynamic_weight_bits,
const std::vector<std::string> &quant_post_dynamic_op_types) {
xpu_quant_post_dynamic_weight_bits_ = quant_post_dynamic_weight_bits;
xpu_quant_post_dynamic_op_types_ = quant_post_dynamic_op_types;
Update();
}
void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
int device_id,
Precision precision_mode) {
......@@ -489,6 +497,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(xpu_precision_);
CP_MEMBER(xpu_adaptive_seqlen_);
CP_MEMBER(xpu_enable_multi_stream_);
CP_MEMBER(xpu_quant_post_dynamic_weight_bits_);
CP_MEMBER(xpu_quant_post_dynamic_op_types_);
// Lite OpenCL Related
CP_MEMBER(use_opencl_);
......@@ -1091,6 +1101,10 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << xpu_precision_;
ss << xpu_adaptive_seqlen_;
ss << xpu_enable_multi_stream_;
ss << xpu_quant_post_dynamic_weight_bits_;
for (auto op_type : xpu_quant_post_dynamic_op_types_) {
ss << op_type;
}
ss << use_npu_;
ss << npu_device_id_;
......@@ -1331,6 +1345,13 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
os.InsertRow(
{"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
os.InsertRow({"xpu_quant_post_dynamic_weight_bits",
std::to_string(xpu_quant_post_dynamic_weight_bits_)});
std::vector<std::string> op_types{"xpu_quant_post_dynamic_op_types"};
for (auto op_type : xpu_quant_post_dynamic_op_types_) {
op_types.push_back(op_type);
}
os.InsertRow(op_types);
}
os.InsetDivider();
......
......@@ -1426,6 +1426,10 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_);
argument_->SetXpuDeviceId(config_.xpu_device_id_);
argument_->SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_);
argument_->SetXpuQuantPostDynamicWeightBits(
config_.xpu_quant_post_dynamic_weight_bits_);
argument_->SetXpuQuantPostDynamicOpTypes(
config_.xpu_quant_post_dynamic_op_types_);
#endif
auto *pass_builder = config_.pass_builder();
......
......@@ -288,6 +288,18 @@ struct PD_INFER_DECL AnalysisConfig {
bool adaptive_seqlen = false,
bool enable_multi_stream = false);
///
/// \brief configs of XPU
///
/// \param quant_post_dynamic_weight_bits Weight bits used in dynamic post
/// quantization. Optional values: -1, 8, 16. The default is -1, which means
/// the recommended setting is used.
/// \param quant_post_dynamic_op_types Op types to apply dynamic post
/// quantization to.
///
void SetXpuConfig(
int quant_post_dynamic_weight_bits = -1,
const std::vector<std::string>& quant_post_dynamic_op_types = {});
///
/// \brief configs of IPU
///
......@@ -1181,6 +1193,8 @@ struct PD_INFER_DECL AnalysisConfig {
std::string xpu_precision_;
bool xpu_adaptive_seqlen_;
bool xpu_enable_multi_stream_;
int xpu_quant_post_dynamic_weight_bits_{-1};
std::vector<std::string> xpu_quant_post_dynamic_op_types_;
// LITE OPENCL SETTINGS
bool use_opencl_{false};
......
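On the API side, the new `SetXpuConfig` entry point above feeds `xpu_quant_post_dynamic_weight_bits` and `xpu_quant_post_dynamic_op_types` through the `Argument` into `fused_multi_transformer_xpu_quant_pass` (see the `ir_pass_manager` change earlier). A rough C++ usage sketch follows; the model path is a placeholder and the call sequence is an assumption about typical usage rather than code from the patch.

```cpp
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("/path/to/model_dir");  // placeholder path
  config.EnableXpu();                     // default XPU settings

  // Request 8-bit weight-only quantization. "fused_multi_transformer" matches
  // the op type checked in ir_pass_manager before quant_weight_bits is set on
  // fused_multi_transformer_xpu_quant_pass. The defaults (-1, {}) keep the
  // recommended behaviour, i.e. the pass falls back to int16 weights.
  config.SetXpuConfig(/*quant_post_dynamic_weight_bits=*/8,
                      /*quant_post_dynamic_op_types=*/{"fused_multi_transformer"});
  return 0;
}
```

The Python binding added in this commit exposes the same call as `config.set_xpu_config(quant_post_dynamic_weight_bits, quant_post_dynamic_op_types)`.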
......@@ -767,6 +767,11 @@ void BindAnalysisConfig(py::module *m) {
.def("set_xpu_device_id",
&AnalysisConfig::SetXpuDeviceId,
py::arg("device_id") = 0)
.def(
"set_xpu_config",
&AnalysisConfig::SetXpuConfig,
py::arg("quant_post_dynamic_weight_bits") = -1,
py::arg("quant_post_dynamic_op_types") = std::vector<std::string>({}))
.def("enable_custom_device",
&AnalysisConfig::EnableCustomDevice,
py::arg("device_type"),
......