From 735864a0560d30a2e657fb987a9164d826a0fc5e Mon Sep 17 00:00:00 2001 From: MyPandaShaoxiang Date: Mon, 6 Jan 2020 05:11:20 -0500 Subject: [PATCH] feat: add dyanmic quant fuse pass --- .../mir/fusion/quant_dequant_fuse_pass.cc | 16 +- .../core/mir/fusion/quant_dequant_op_fuser.cc | 246 +++++++++++++++++- lite/core/mir/fusion/quant_dequant_op_fuser.h | 31 +++ lite/operators/fake_quantize_range_abs_max.cc | 2 + lite/operators/fake_quantize_range_abs_max.h | 6 +- 5 files changed, 295 insertions(+), 6 deletions(-) diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index ff5a7a1f25..2720404fb0 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -27,10 +27,24 @@ namespace mir { void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { // delete quant node std::vector quant_op_types = { - "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; + "fake_quantize_abs_max", + "fake_quantize_range_abs_max", + "fake_quantize_moving_average_abs_max"}; + /* + for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) { + for (int i = 5; i >= 1; --i){ + fusion::DynamicQuantDequantOpFuser fuser("fake_quantize_abs_max", op_type, + i); + fuser(graph.get()); + } + } + */ + for (auto& op_type : quant_op_types) { fusion::DeleteQuantOpFuser fuser(op_type); fuser(graph.get()); + fusion::DeleteDynamicQuantOpFuser dfuser(op_type); + dfuser(graph.get()); } // fuse quantized node and dequant node diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index da611e4490..9c5bb24d88 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -77,6 +77,55 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) { return op_desc; } +void DeleteDynamicQuantOpFuser::BuildPattern() { + auto* input_act_node = + VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X"); + auto* quant_node = + OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_); + auto* output_scale_node = + VarNode("output_scale_node") + ->assert_is_op_output(quant_op_type_, "OutScale"); + auto* output_act_node = + VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out"); + + quant_node->LinksFrom({input_act_node}); + output_scale_node->LinksFrom({quant_node}); + output_act_node->LinksFrom({quant_node}); + VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_; +} + +void DeleteDynamicQuantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto* input_act_node = matched.at("input_act_node"); + auto* quant_node = matched.at("quant_node"); + auto* output_scale_node = matched.at("output_scale_node"); + auto* output_act_node = matched.at("output_act_node"); + + // obtain values, save values and relink node + int bit_length = quant_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_node->stmt()->op()->scope(); + auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name) + ->GetMutable(); + float scale_value = scale_tensor->data()[0] / range; + + auto outlinks = output_act_node->outlinks; + for (auto* quantized_node : outlinks) { + auto* op_desc = quantized_node->stmt()->mutable_op_info(); + op_desc->SetAttr("bit_length", bit_length); + IR_NODE_LINK_TO(input_act_node, quantized_node) + } + + // delete nodes and edges + std::unordered_set nodes2rm = { + quant_node, output_scale_node, output_act_node}; + GraphSafeRemoveNodes(graph, nodes2rm); +} + +cpp::OpDesc DeleteDynamicQuantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; +} void DequantOpFuser::BuildPattern() { std::string weight_name = ""; if (quantized_op_type_ == "conv2d" || @@ -130,8 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, auto& valid_places = quantized_op->stmt()->op()->valid_places(); int bit_length = quantized_op->stmt()->op_info()->GetAttr("bit_length"); int range = ((1 << (bit_length - 1)) - 1); - float input_scale = - quantized_op->stmt()->op_info()->GetAttr("input_scale"); + float input_scale = 0; + if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) { + input_scale = + quantized_op->stmt()->op_info()->GetAttr("input_scale"); + } float max_range = dequant_op->stmt()->op_info()->GetAttr("max_range"); float whole_weight_scale = static_cast(range * range) / max_range / range; @@ -163,7 +215,9 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, weight_scale.push_back(whole_weight_scale); } op_desc.SetAttr("enable_int8", true); - op_desc.SetAttr("input_scale", input_scale); + if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) { + op_desc.SetAttr("input_scale", input_scale); + } op_desc.SetAttr("weight_scale", weight_scale); // change the weight from the float type to int8 type. @@ -464,6 +518,192 @@ cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc op_desc; return op_desc; } +// ================dynamic quant fuse============== +// #define DYNAMIC_RANGE +void DynamicQuantDequantOpFuser::BuildPattern() { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kQuantizedOpOutOffset = 2; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + std::string weight_name = ""; + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + weight_name = "Filter"; + } else { + weight_name = "Y"; + } + auto* quant_op_input = VarNode("quant_op_input") + ->assert_is_op_input(quant_type_, "X") + ->AsInput(); +#ifdef DYNAMIC_RANGE + auto* quant_op_in_scale = VarNode("quant_op_in_scale") + ->assert_is_op_input(quant_type_, "InScale") + ->AsIntermediate(); +#endif + auto* quant_op = OpNode("quant_op", quant_type_) + ->assert_is_op(quant_type_) + ->AsIntermediate(); + + auto* quant_op_out_scale = + VarNode("quant_op_out_scale") + ->assert_is_op_output(quant_type_, "OutScale") + ->assert_is_op_input("fake_dequantize_max_abs", "Scale") + ->AsIntermediate(); + + auto* quant_op_out = VarNode("quant_op_out") + ->assert_is_op_output(quant_type_, "Out") + ->assert_is_op_input(op_type_) + ->AsIntermediate(); + std::vector nodes; + for (int i = 0; i < times_; i++) { + nodes.push_back(VarNode(string_format("quantized_op_weight%d", i)) + ->assert_is_op_input(op_type_, weight_name) + ->AsInput()); + + nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_) + ->assert_is_op(op_type_) + ->AsIntermediate()); + + nodes.push_back(VarNode(string_format("quantized_op_out%d", i)) + ->assert_is_op_output(op_type_) + ->assert_is_op_input("fake_dequantize_max_abs", "X") + ->AsIntermediate()); + + nodes.push_back( + OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs") + ->assert_is_op("fake_dequantize_max_abs") + ->AsIntermediate()); + nodes.push_back(VarNode(string_format("dequant_op_out%d", i)) + ->assert_is_op_output("fake_dequantize_max_abs", "Out") + ->AsOutput()); + } + +#ifdef DYNAMIC_RANGE + quant_op->LinksFrom({quant_op_input, quant_op_in_scale}); +#endif + quant_op->LinksFrom({quant_op_input}); + quant_op_out->LinksFrom({quant_op}); + quant_op_out_scale->LinksFrom({quant_op}); + for (int i = 0; i < times_; i++) { + nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom( + {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]}); + nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOffset]}); + nodes[i * kNumFields + kDequantOpOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale}); + nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kDequantOpOffset]}); + } +} + +void DynamicQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + auto* quant_op_input = matched.at("quant_op_input"); +#ifdef DYNAMIC_RANGE + auto* quant_op_in_scale = matched.at("quant_op_in_scale"); +#endif + auto* quant_op = matched.at("quant_op"); + + std::vector nodes; + for (int i = 0; i < times_; i++) { + nodes.push_back(matched.at(string_format("quantized_op_weight%d", i))); + nodes.push_back(matched.at(string_format("quantized_op%d", i))); + nodes.push_back(matched.at(string_format("quantized_op_out%d", i))); + nodes.push_back(matched.at(string_format("dequant_op%d", i))); + nodes.push_back(matched.at(string_format("dequant_op_out%d", i))); + } + int bit_length = quant_op->stmt()->op_info()->GetAttr("bit_length"); + auto* scope = quant_op->stmt()->op()->scope(); + auto& valid_places = quant_op->stmt()->op()->valid_places(); + int range = ((1 << (bit_length - 1)) - 1); + +#ifdef DYNAMIC_RANGE + auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name) + ->GetMutable(); + float input_scale = input_scale_t->data()[0] / range; + VLOG(4) << "range: " << range << " input_scale: " << input_scale; +#endif + for (int i = 0; i < times_; i++) { + float max_range = nodes[i * kNumFields + kDequantOpOffset] + ->stmt() + ->op_info() + ->GetAttr("max_range"); + // weight_scale = max(abs(weight)) + float whole_weight_scale = + static_cast(range * range) / max_range / range; + + cpp::OpDesc op_desc = + *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info(); + + auto quantized_weight_var_name = + nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name; + auto quantized_weight_t = + scope->FindVar(quantized_weight_var_name)->GetMutable(); + std::vector weight_scale; + int weight_scale_size; + + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name}); + op_desc.SetOutput( + "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); + // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should + // be Cout. + weight_scale_size = quantized_weight_t->dims()[0]; + } else if (op_type_ == "mul") { + op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name}); + op_desc.SetOutput( + "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); + // Fc weight: Cin * Cout, the weight_scale_size should be Cout. + weight_scale_size = quantized_weight_t->dims()[1]; + } + for (int i = 0; i < weight_scale_size; i++) { + weight_scale.push_back(whole_weight_scale); + } + // op_desc.SetAttr("enable_int8", true); + // op_desc.SetAttr("input_scale", input_scale); + op_desc.SetAttr("weight_scale", weight_scale); + + Tensor temp_tensor; + temp_tensor.CopyDataFrom(*quantized_weight_t); + float* temp_data = temp_tensor.mutable_data(); + size_t weight_num = quantized_weight_t->data_size(); + quantized_weight_t->set_persistable(true); +#ifdef LITE_WITH_FPGA + float* quantized_weight_data = quantized_weight_t->mutable_data(); + for (size_t i = 0; i < weight_num; i++) { + quantized_weight_data[i] = temp_data[i] * whole_weight_scale; + } + quantized_weight_t->set_precision(PRECISION(kFloat)); +#else + int8_t* quantized_weight_data = quantized_weight_t->mutable_data(); + for (size_t i = 0; i < weight_num; i++) { + quantized_weight_data[i] = static_cast(temp_data[i]); + } + quantized_weight_t->set_precision(PRECISION(kInt8)); +#endif + auto quantized_op = LiteOpRegistry::Global().Create(op_type_); + quantized_op->Attach(op_desc, scope); + auto* new_op_node = + graph->GraphCreateInstructNode(quantized_op, valid_places); + IR_NODE_LINK_TO(quant_op_input, new_op_node); + IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], + new_op_node); + IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]); + } +} + +cpp::OpDesc DynamicQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; +} } // namespace fusion } // namespace mir diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h index bef9f4d957..c21df350f9 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.h +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h @@ -52,6 +52,19 @@ class DeleteQuantOpFuser : public FuseBase { private: std::string quant_op_type_{}; }; +class DeleteDynamicQuantOpFuser : public FuseBase { + public: + explicit DeleteDynamicQuantOpFuser(const std::string& quant_op_type) + : quant_op_type_(quant_op_type) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + + private: + std::string quant_op_type_{}; +}; /* DequantOpFuser process conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs. */ @@ -106,6 +119,24 @@ class DeleteQuantDequantOpFuser : public FuseBase { private: std::string quantized_op_type_{}; }; +// dynamic quantdequant op fuser +class DynamicQuantDequantOpFuser : public FuseBase { + public: + explicit DynamicQuantDequantOpFuser(const std::string& quantized_op_type, + const std::string& op_type, + int i) + : op_type_(op_type), quant_type_(quantized_op_type), times_(i) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + + private: + std::string op_type_{}; + std::string quant_type_{}; + int times_{1}; +}; } // namespace fusion } // namespace mir diff --git a/lite/operators/fake_quantize_range_abs_max.cc b/lite/operators/fake_quantize_range_abs_max.cc index a8ce3f75a5..ebf7e41f4b 100644 --- a/lite/operators/fake_quantize_range_abs_max.cc +++ b/lite/operators/fake_quantize_range_abs_max.cc @@ -23,3 +23,5 @@ namespace operators {} // namespace operators REGISTER_LITE_OP(fake_quantize_range_abs_max, paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite); +REGISTER_LITE_OP(fake_quantize_abs_max, + paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite); diff --git a/lite/operators/fake_quantize_range_abs_max.h b/lite/operators/fake_quantize_range_abs_max.h index 726731595a..f68d1e20f6 100644 --- a/lite/operators/fake_quantize_range_abs_max.h +++ b/lite/operators/fake_quantize_range_abs_max.h @@ -40,13 +40,15 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); - auto in_scale = op_desc.Input("InScale").front(); + if (op_desc.HasInput("InScale")) { + auto in_scale = op_desc.Input("InScale").front(); + param_.in_scale = scope->FindVar(in_scale)->GetMutable(); + } auto out = op_desc.Output("Out").front(); auto out_scale = op_desc.Output("OutScale").front(); param_.x = scope->FindVar(x)->GetMutable(); - param_.in_scale = scope->FindVar(in_scale)->GetMutable(); param_.out = scope->FindVar(out)->GetMutable(); param_.out_scale = scope->FindVar(out_scale)->GetMutable(); -- GitLab