提交 735864a0 编写于 作者: M MyPandaShaoxiang

feat: add dyanmic quant fuse pass

上级 0a075279
...@@ -27,10 +27,24 @@ namespace mir { ...@@ -27,10 +27,24 @@ namespace mir {
void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// delete quant node // delete quant node
std::vector<std::string> quant_op_types = { std::vector<std::string> quant_op_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; "fake_quantize_abs_max",
"fake_quantize_range_abs_max",
"fake_quantize_moving_average_abs_max"};
/*
for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) {
for (int i = 5; i >= 1; --i){
fusion::DynamicQuantDequantOpFuser fuser("fake_quantize_abs_max", op_type,
i);
fuser(graph.get());
}
}
*/
for (auto& op_type : quant_op_types) { for (auto& op_type : quant_op_types) {
fusion::DeleteQuantOpFuser fuser(op_type); fusion::DeleteQuantOpFuser fuser(op_type);
fuser(graph.get()); fuser(graph.get());
fusion::DeleteDynamicQuantOpFuser dfuser(op_type);
dfuser(graph.get());
} }
// fuse quantized node and dequant node // fuse quantized node and dequant node
......
...@@ -77,6 +77,55 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -77,6 +77,55 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
return op_desc; return op_desc;
} }
void DeleteDynamicQuantOpFuser::BuildPattern() {
auto* input_act_node =
VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X");
auto* quant_node =
OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_op_type_, "OutScale");
auto* output_act_node =
VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out");
quant_node->LinksFrom({input_act_node});
output_scale_node->LinksFrom({quant_node});
output_act_node->LinksFrom({quant_node});
VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_;
}
void DeleteDynamicQuantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* input_act_node = matched.at("input_act_node");
auto* quant_node = matched.at("quant_node");
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
// obtain values, save values and relink node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto outlinks = output_act_node->outlinks;
for (auto* quantized_node : outlinks) {
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
IR_NODE_LINK_TO(input_act_node, quantized_node)
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
quant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph, nodes2rm);
}
cpp::OpDesc DeleteDynamicQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
void DequantOpFuser::BuildPattern() { void DequantOpFuser::BuildPattern() {
std::string weight_name = ""; std::string weight_name = "";
if (quantized_op_type_ == "conv2d" || if (quantized_op_type_ == "conv2d" ||
...@@ -130,8 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -130,8 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
auto& valid_places = quantized_op->stmt()->op()->valid_places(); auto& valid_places = quantized_op->stmt()->op()->valid_places();
int bit_length = quantized_op->stmt()->op_info()->GetAttr<int>("bit_length"); int bit_length = quantized_op->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1); int range = ((1 << (bit_length - 1)) - 1);
float input_scale = float input_scale = 0;
if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) {
input_scale =
quantized_op->stmt()->op_info()->GetAttr<float>("input_scale"); quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
}
float max_range = dequant_op->stmt()->op_info()->GetAttr<float>("max_range"); float max_range = dequant_op->stmt()->op_info()->GetAttr<float>("max_range");
float whole_weight_scale = float whole_weight_scale =
static_cast<float>(range * range) / max_range / range; static_cast<float>(range * range) / max_range / range;
...@@ -163,7 +215,9 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -163,7 +215,9 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
weight_scale.push_back(whole_weight_scale); weight_scale.push_back(whole_weight_scale);
} }
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) {
op_desc.SetAttr("input_scale", input_scale); op_desc.SetAttr("input_scale", input_scale);
}
op_desc.SetAttr("weight_scale", weight_scale); op_desc.SetAttr("weight_scale", weight_scale);
// change the weight from the float type to int8 type. // change the weight from the float type to int8 type.
...@@ -464,6 +518,192 @@ cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -464,6 +518,192 @@ cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc; cpp::OpDesc op_desc;
return op_desc; return op_desc;
} }
// ================dynamic quant fuse==============
// #define DYNAMIC_RANGE
void DynamicQuantDequantOpFuser::BuildPattern() {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
std::string weight_name = "";
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
weight_name = "Filter";
} else {
weight_name = "Y";
}
auto* quant_op_input = VarNode("quant_op_input")
->assert_is_op_input(quant_type_, "X")
->AsInput();
#ifdef DYNAMIC_RANGE
auto* quant_op_in_scale = VarNode("quant_op_in_scale")
->assert_is_op_input(quant_type_, "InScale")
->AsIntermediate();
#endif
auto* quant_op = OpNode("quant_op", quant_type_)
->assert_is_op(quant_type_)
->AsIntermediate();
auto* quant_op_out_scale =
VarNode("quant_op_out_scale")
->assert_is_op_output(quant_type_, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto* quant_op_out = VarNode("quant_op_out")
->assert_is_op_output(quant_type_, "Out")
->assert_is_op_input(op_type_)
->AsIntermediate();
std::vector<PMNode*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(VarNode(string_format("quantized_op_weight%d", i))
->assert_is_op_input(op_type_, weight_name)
->AsInput());
nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_)
->assert_is_op(op_type_)
->AsIntermediate());
nodes.push_back(VarNode(string_format("quantized_op_out%d", i))
->assert_is_op_output(op_type_)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate());
nodes.push_back(
OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs")
->assert_is_op("fake_dequantize_max_abs")
->AsIntermediate());
nodes.push_back(VarNode(string_format("dequant_op_out%d", i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->AsOutput());
}
#ifdef DYNAMIC_RANGE
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
#endif
quant_op->LinksFrom({quant_op_input});
quant_op_out->LinksFrom({quant_op});
quant_op_out_scale->LinksFrom({quant_op});
for (int i = 0; i < times_; i++) {
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
}
void DynamicQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
auto* quant_op_input = matched.at("quant_op_input");
#ifdef DYNAMIC_RANGE
auto* quant_op_in_scale = matched.at("quant_op_in_scale");
#endif
auto* quant_op = matched.at("quant_op");
std::vector<Node*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(matched.at(string_format("quantized_op_weight%d", i)));
nodes.push_back(matched.at(string_format("quantized_op%d", i)));
nodes.push_back(matched.at(string_format("quantized_op_out%d", i)));
nodes.push_back(matched.at(string_format("dequant_op%d", i)));
nodes.push_back(matched.at(string_format("dequant_op_out%d", i)));
}
int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
auto* scope = quant_op->stmt()->op()->scope();
auto& valid_places = quant_op->stmt()->op()->valid_places();
int range = ((1 << (bit_length - 1)) - 1);
#ifdef DYNAMIC_RANGE
auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
->GetMutable<lite::Tensor>();
float input_scale = input_scale_t->data<float>()[0] / range;
VLOG(4) << "range: " << range << " input_scale: " << input_scale;
#endif
for (int i = 0; i < times_; i++) {
float max_range = nodes[i * kNumFields + kDequantOpOffset]
->stmt()
->op_info()
->GetAttr<float>("max_range");
// weight_scale = max(abs(weight))
float whole_weight_scale =
static_cast<float>(range * range) / max_range / range;
cpp::OpDesc op_desc =
*nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info();
auto quantized_weight_var_name =
nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name;
auto quantized_weight_t =
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
std::vector<float> weight_scale;
int weight_scale_size;
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name});
op_desc.SetOutput(
"Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
// Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
// be Cout.
weight_scale_size = quantized_weight_t->dims()[0];
} else if (op_type_ == "mul") {
op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name});
op_desc.SetOutput(
"Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
weight_scale_size = quantized_weight_t->dims()[1];
}
for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale);
}
// op_desc.SetAttr("enable_int8", true);
// op_desc.SetAttr("input_scale", input_scale);
op_desc.SetAttr("weight_scale", weight_scale);
Tensor temp_tensor;
temp_tensor.CopyDataFrom(*quantized_weight_t);
float* temp_data = temp_tensor.mutable_data<float>();
size_t weight_num = quantized_weight_t->data_size();
quantized_weight_t->set_persistable(true);
#ifdef LITE_WITH_FPGA
float* quantized_weight_data = quantized_weight_t->mutable_data<float>();
for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = temp_data[i] * whole_weight_scale;
}
quantized_weight_t->set_precision(PRECISION(kFloat));
#else
int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
}
quantized_weight_t->set_precision(PRECISION(kInt8));
#endif
auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
quantized_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(quantized_op, valid_places);
IR_NODE_LINK_TO(quant_op_input, new_op_node);
IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
new_op_node);
IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
}
}
cpp::OpDesc DynamicQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
} // namespace fusion } // namespace fusion
} // namespace mir } // namespace mir
......
...@@ -52,6 +52,19 @@ class DeleteQuantOpFuser : public FuseBase { ...@@ -52,6 +52,19 @@ class DeleteQuantOpFuser : public FuseBase {
private: private:
std::string quant_op_type_{}; std::string quant_op_type_{};
}; };
class DeleteDynamicQuantOpFuser : public FuseBase {
public:
explicit DeleteDynamicQuantOpFuser(const std::string& quant_op_type)
: quant_op_type_(quant_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string quant_op_type_{};
};
/* DequantOpFuser process conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs. /* DequantOpFuser process conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs.
*/ */
...@@ -106,6 +119,24 @@ class DeleteQuantDequantOpFuser : public FuseBase { ...@@ -106,6 +119,24 @@ class DeleteQuantDequantOpFuser : public FuseBase {
private: private:
std::string quantized_op_type_{}; std::string quantized_op_type_{};
}; };
// dynamic quantdequant op fuser
class DynamicQuantDequantOpFuser : public FuseBase {
public:
explicit DynamicQuantDequantOpFuser(const std::string& quantized_op_type,
const std::string& op_type,
int i)
: op_type_(op_type), quant_type_(quantized_op_type), times_(i) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string op_type_{};
std::string quant_type_{};
int times_{1};
};
} // namespace fusion } // namespace fusion
} // namespace mir } // namespace mir
......
...@@ -23,3 +23,5 @@ namespace operators {} // namespace operators ...@@ -23,3 +23,5 @@ namespace operators {} // namespace operators
REGISTER_LITE_OP(fake_quantize_range_abs_max, REGISTER_LITE_OP(fake_quantize_range_abs_max,
paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite); paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite);
REGISTER_LITE_OP(fake_quantize_abs_max,
paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite);
...@@ -40,13 +40,15 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite { ...@@ -40,13 +40,15 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
if (op_desc.HasInput("InScale")) {
auto in_scale = op_desc.Input("InScale").front(); auto in_scale = op_desc.Input("InScale").front();
param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
}
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
auto out_scale = op_desc.Output("OutScale").front(); auto out_scale = op_desc.Output("OutScale").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>(); param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>(); param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册