From 253acb801a59aa8aecc92b65746d9ab64101e6d5 Mon Sep 17 00:00:00 2001
From: juncaipeng <52520497+juncaipeng@users.noreply.github.com>
Date: Mon, 14 Oct 2019 10:51:40 +0800
Subject: [PATCH] Optimize quant_dequant_fuse_pass (#2169)

* optimize quant_dequant_fuse_pass, test=develop
---
 lite/api/mobilenetv1_int8_test.cc             |  15 ++
 .../mir/fusion/quant_dequant_fuse_pass.cc     |  58 ++++-
 .../core/mir/fusion/quant_dequant_op_fuser.cc | 238 +++++++-----------
 lite/core/mir/fusion/quant_dequant_op_fuser.h |  10 +-
 lite/operators/op_params.h                    |   3 +-
 5 files changed, 159 insertions(+), 165 deletions(-)

diff --git a/lite/api/mobilenetv1_int8_test.cc b/lite/api/mobilenetv1_int8_test.cc
index 472e2a2595..d3ac115fa2 100644
--- a/lite/api/mobilenetv1_int8_test.cc
+++ b/lite/api/mobilenetv1_int8_test.cc
@@ -74,6 +74,21 @@ void TestModel(const std::vector<Place>& valid_places,
                   1e-6);
     }
   }
+
+  auto* out_data = out->data<float>();
+  LOG(INFO) << "output data:";
+  for (int i = 0; i < out->numel(); i += step) {
+    LOG(INFO) << out_data[i];
+  }
+  float max_val = out_data[0];
+  int max_val_arg = 0;
+  for (int i = 1; i < out->numel(); i++) {
+    if (max_val < out_data[i]) {
+      max_val = out_data[i];
+      max_val_arg = i;
+    }
+  }
+  LOG(INFO) << "max val:" << max_val << ", max_val_arg:" << max_val_arg;
 }
 
 TEST(MobileNetV1, test_arm) {
diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
index 15fdff5edf..ecf2508700 100644
--- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "lite/core/mir/fusion/quant_dequant_fuse_pass.h"
+#include <list>
 #include <memory>
+#include <unordered_set>
 #include <vector>
 #include "lite/api/paddle_place.h"
 #include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
@@ -24,18 +26,60 @@ namespace lite {
 namespace mir {
 
 void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // obtain useful values and save to quantized_node, remove quant_nodes and
+  // related nodes
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
-  std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "depthwise_conv2d"};
-  for (auto& quant_type : quant_types) {
-    for (auto& op_type : quantized_op_types) {
-      for (int i = 6; i >= 1; i--) {
-        fusion::QuantDequantOpFuser fuser(op_type, quant_type, i);
-        fuser(graph.get());
+  for (auto& cur_node : graph->mutable_nodes()) {
+    if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) {
+      // find input nodes and output nodes
+      std::list<Node*> input_nodes = cur_node.inlinks;
+      std::list<Node*> output_nodes = cur_node.outlinks;
+      CHECK_EQ(input_nodes.size(), 2);
+      CHECK_EQ(output_nodes.size(), 2);
+
+      bool front_is_scale = input_nodes.front()->arg()->is_weight;
+      Node* input_scale_node =
+          front_is_scale ? input_nodes.front() : input_nodes.back();
+      Node* input_act_node =
+          front_is_scale ? input_nodes.back() : input_nodes.front();
+      front_is_scale = output_nodes.front()->arg()->is_weight;
+      Node* output_scale_node =
+          front_is_scale ? output_nodes.front() : output_nodes.back();
+      Node* output_act_node =
+          front_is_scale ? output_nodes.back() : output_nodes.front();
+
+      // relink nodes and save value to quantized_node
+      int bit_length = cur_node.stmt()->op_info()->GetAttr<int>("bit_length");
+      int range = ((1 << (bit_length - 1)) - 1);
+      auto* scope = cur_node.stmt()->op()->scope();
+      auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
+                              ->GetMutable<lite::Tensor>();
+      float scale_value = scale_tensor->data<float>()[0] / range;
+
+      for (auto* quantized_node_ptr : output_act_node->outlinks) {
+        quantized_node_ptr->stmt()->mutable_op_info()->SetAttr(
+            "bit_length", bit_length);
+        quantized_node_ptr->stmt()->mutable_op_info()->SetAttr(
+            "input_scale", scale_value);
+        IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
+        RemoveDirectedLink(output_act_node, quantized_node_ptr);
       }
+
+      // delete nodes and edges
+      std::unordered_set<const Node*> nodes2rm = {
+          input_scale_node, &cur_node, output_scale_node, output_act_node};
+      GraphSafeRemoveNodes(graph.get(), nodes2rm);
     }
   }
+
+  // fuse quantized node and dequant node
+  std::unordered_set<std::string> quantized_op_types = {
+      "conv2d", "mul", "depthwise_conv2d"};
+  for (auto& op_type : quantized_op_types) {
+    fusion::QuantDequantOpFuser fuser(op_type);
+    fuser(graph.get());
+  }
 }
 
 }  // namespace mir
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
index 1c7cf866b9..a0ede90446 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
@@ -23,170 +23,108 @@ namespace mir {
 namespace fusion {
 
 void QuantDequantOpFuser::BuildPattern() {
-  const int kNumFields = 5;
-  const int kQuantizedWeightOffset = 0;
-  const int kQuantizedOpOffset = 1;
-  const int kQuantizedOpOutOffset = 2;
-  const int kDequantOpOffset = 3;
-  const int kDequantOpOutOffset = 4;
-
   std::string weight_name = "";
   if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
     weight_name = "Filter";
   } else {
     weight_name = "Y";
   }
-  auto* quant_op_input = VarNode("quant_op_input")
-                             ->assert_is_op_input(quant_type_, "X")
-                             ->AsInput();
-  auto* quant_op_in_scale = VarNode("quant_op_in_scale")
-                                ->assert_is_op_input(quant_type_, "InScale")
-                                ->AsIntermediate();
-  auto* quant_op = OpNode("quant_op", quant_type_)
-                       ->assert_is_op(quant_type_)
-                       ->AsIntermediate();
-
-  auto* quant_op_out_scale =
-      VarNode("quant_op_out_scale")
-          ->assert_is_op_output(quant_type_, "OutScale")
-          ->assert_is_op_input("fake_dequantize_max_abs", "Scale")
-          ->AsIntermediate();
-  auto* quant_op_out = VarNode("quant_op_out")
-                           ->assert_is_op_output(quant_type_, "Out")
-                           ->assert_is_op_input(op_type_)
+  auto* quantized_op_input =
+      VarNode("quantized_op_input")->assert_is_op_input(op_type_)->AsInput();
+  auto* quantized_op_weight = VarNode("quantized_op_weight")
+                                  ->assert_is_op_input(op_type_, weight_name)
+                                  ->AsInput();
+  auto* quantized_op = OpNode("quantized_op", op_type_)
+                           ->assert_is_op(op_type_)
                            ->AsIntermediate();
-  std::vector<PMNode*> nodes;
-  for (int i = 0; i < times_; i++) {
-    nodes.push_back(VarNode(string_format("quantized_op_weight%d", i))
-                        ->assert_is_op_input(op_type_, weight_name)
-                        ->AsInput());
-
-    nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_)
-                        ->assert_is_op(op_type_)
-                        ->AsIntermediate());
-
-    nodes.push_back(VarNode(string_format("quantized_op_out%d", i))
-                        ->assert_is_op_output(op_type_)
-                        ->assert_is_op_input("fake_dequantize_max_abs", "X")
-                        ->AsIntermediate());
-
-    nodes.push_back(
-        OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs")
-            ->assert_is_op("fake_dequantize_max_abs")
-            ->AsIntermediate());
-
-    nodes.push_back(VarNode(string_format("dequant_op_out%d", i))
-                        ->assert_is_op_output("fake_dequantize_max_abs", "Out")
-                        ->AsOutput());
-  }
-
-  quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
-  quant_op_out->LinksFrom({quant_op});
-  quant_op_out_scale->LinksFrom({quant_op});
-  for (int i = 0; i < times_; i++) {
-    nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
-        {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
-    nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
-        {nodes[i * kNumFields + kQuantizedOpOffset]});
-    nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
-        {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
-    nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
-        {nodes[i * kNumFields + kDequantOpOffset]});
-  }
+  auto* quantized_op_out =
+      VarNode("quantized_op_out")
+          ->assert_is_op_output(op_type_)
+          ->assert_is_op_input("fake_dequantize_max_abs", "X")
+          ->AsIntermediate();
+  auto* dequant_op = OpNode("dequant_op", "fake_dequantize_max_abs")
+                         ->assert_is_op("fake_dequantize_max_abs")
+                         ->AsIntermediate();
+  auto* dequant_op_out =
+      VarNode("dequant_op_out")
+          ->assert_is_op_output("fake_dequantize_max_abs", "Out")
+          ->AsOutput();
+
+  quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
+  quantized_op_out->LinksFrom({quantized_op});
+  dequant_op->LinksFrom({quantized_op_out});
+  dequant_op_out->LinksFrom({dequant_op});
 }
 
 void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
                                         const key2nodes_t& matched) {
-  const int kNumFields = 5;
-  const int kQuantizedWeightOffset = 0;
-  const int kQuantizedOpOffset = 1;
-  const int kDequantOpOffset = 3;
-  const int kDequantOpOutOffset = 4;
-
-  auto* quant_op_input = matched.at("quant_op_input");
-  auto* quant_op_in_scale = matched.at("quant_op_in_scale");
-  auto* quant_op = matched.at("quant_op");
-
-  std::vector<Node*> nodes;
-  for (int i = 0; i < times_; i++) {
-    nodes.push_back(matched.at(string_format("quantized_op_weight%d", i)));
-    nodes.push_back(matched.at(string_format("quantized_op%d", i)));
-    nodes.push_back(matched.at(string_format("quantized_op_out%d", i)));
-    nodes.push_back(matched.at(string_format("dequant_op%d", i)));
-    nodes.push_back(matched.at(string_format("dequant_op_out%d", i)));
-  }
-  int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
-  auto* scope = quant_op->stmt()->op()->scope();
-  auto& valid_places = quant_op->stmt()->op()->valid_places();
+  auto* quant_op_input = matched.at("quantized_op_input");
+  auto* quantized_op_weight = matched.at("quantized_op_weight");
+  auto* quantized_op = matched.at("quantized_op");
+  auto* dequant_op = matched.at("dequant_op");
+  auto* dequant_op_out = matched.at("dequant_op_out");
+
+  // obtain input_scale and weight_scale
+  auto* scope = quantized_op->stmt()->op()->scope();
+  auto& valid_places = quantized_op->stmt()->op()->valid_places();
+  int bit_length =
+      quantized_op->stmt()->op_info()->GetAttr<int>("bit_length");
   int range = ((1 << (bit_length - 1)) - 1);
-  auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
-                           ->GetMutable<lite::Tensor>();
-  float input_scale = input_scale_t->data<float>()[0] / range;
-
-  VLOG(4) << "range: " << range << " input_scale: " << input_scale;
-  for (int i = 0; i < times_; i++) {
-    float max_range = nodes[i * kNumFields + kDequantOpOffset]
-                          ->stmt()
-                          ->op_info()
-                          ->GetAttr<float>("max_range");
-    // weight_scale = max(abs(weight))
-    float whole_weight_scale =
-        static_cast<float>(range * range) / max_range / range;
-
-    cpp::OpDesc op_desc =
-        *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info();
-
-    auto quantized_weight_var_name =
-        nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name;
-    auto quantized_weight_t =
-        scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
-    std::vector<float> weight_scale;
-    int weight_scale_size;
-
-    if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
-      op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name});
-      op_desc.SetOutput(
-          "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
-      // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
-      // be Cout.
-      weight_scale_size = quantized_weight_t->dims()[0];
-    } else if (op_type_ == "mul") {
-      op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name});
-      op_desc.SetOutput(
-          "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
-      // Fc weight: Cin * Cout, the weight_scale_size should be Cout.
-      weight_scale_size = quantized_weight_t->dims()[1];
-    }
-    for (int i = 0; i < weight_scale_size; i++) {
-      weight_scale.push_back(whole_weight_scale);
-    }
-    op_desc.SetAttr("enable_int8", true);
-    op_desc.SetAttr("input_scale", input_scale);
-    op_desc.SetAttr("weight_scale", weight_scale);
-
-    Tensor temp_tensor;
-    temp_tensor.CopyDataFrom(*quantized_weight_t);
-    float* temp_data = temp_tensor.mutable_data<float>();
-
-    size_t weight_num = quantized_weight_t->data_size();
-    int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
-
-    // change the weight from the float type to int8 type.
-    for (size_t i = 0; i < weight_num; i++) {
-      quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
-    }
-    quantized_weight_t->set_persistable(true);
-    quantized_weight_t->set_precision(PRECISION(kInt8));
-    auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
-
-    quantized_op->Attach(op_desc, scope);
-    auto* new_op_node =
-        graph->GraphCreateInstructNode(quantized_op, valid_places);
-    IR_NODE_LINK_TO(quant_op_input, new_op_node);
-    IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
-                    new_op_node);
-    IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
+  float input_scale =
+      quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
+  float max_range =
+      dequant_op->stmt()->op_info()->GetAttr<float>("max_range");
+  float whole_weight_scale =
+      static_cast<float>(range * range) / max_range / range;
+  // max_range = range * range / max(abs(weight))
+  // weight_scale = range * range / (range * range / max(abs(weight))) / range
+  //              = max(abs(weight)) / range
+
+  // set op desc
+  cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
+  auto quantized_weight_var_name = quantized_op_weight->arg()->name;
+  auto quantized_weight_t =
+      scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
+  std::vector<float> weight_scale;
+  int weight_scale_size;
+  if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
+    op_desc.SetInput("Input", {quant_op_input->arg()->name});
+    op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
+    // Conv weight shape: Cout * Cin * kh * kw, the weight_scale_size should
+    // be Cout.
+    weight_scale_size = quantized_weight_t->dims()[0];
+  } else if (op_type_ == "mul") {
+    op_desc.SetInput("X", {quant_op_input->arg()->name});
+    op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
+    // Fc weight: Cin * Cout, the weight_scale_size should be Cout.
+    weight_scale_size = quantized_weight_t->dims()[1];
+  }
+  for (int i = 0; i < weight_scale_size; i++) {
+    weight_scale.push_back(whole_weight_scale);
+  }
+  op_desc.SetAttr("enable_int8", true);
+  op_desc.SetAttr("input_scale", input_scale);
+  op_desc.SetAttr("weight_scale", weight_scale);
+
+  // change the weight from the float type to int8 type.
+  Tensor temp_tensor;
+  temp_tensor.CopyDataFrom(*quantized_weight_t);
+  float* temp_data = temp_tensor.mutable_data<float>();
+  size_t weight_num = quantized_weight_t->data_size();
+  int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
+  for (size_t i = 0; i < weight_num; i++) {
+    quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
+  }
+  quantized_weight_t->set_persistable(true);
+  quantized_weight_t->set_precision(PRECISION(kInt8));
+
+  // new op and relink nodes
+  auto new_quantized_op = LiteOpRegistry::Global().Create(op_type_);
+  new_quantized_op->Attach(op_desc, scope);
+  auto* new_quantized_op_node =
+      graph->GraphCreateInstructNode(new_quantized_op, valid_places);
+  IR_NODE_LINK_TO(quant_op_input, new_quantized_op_node);
+  IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
+  IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
 }
 
 cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h
index 15833ad258..b635b58f2f 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.h
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h
@@ -37,10 +37,8 @@ namespace fusion {
  */
 class QuantDequantOpFuser : public FuseBase {
  public:
-  explicit QuantDequantOpFuser(const std::string& op_type,
-                               const std::string& quant_type,
-                               int times)
-      : op_type_(op_type), quant_type_(quant_type), times_(times) {}
+  explicit QuantDequantOpFuser(const std::string& op_type)
+      : op_type_(op_type) {}
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
@@ -48,9 +46,7 @@ class QuantDequantOpFuser : public FuseBase {
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
 
  private:
-  std::string op_type_{"conv2d"};
-  std::string quant_type_;
-  int times_;
+  std::string op_type_{};
 };
 
 }  // namespace fusion
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 76cb4c2b23..119d4f11ea 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -35,7 +35,8 @@ using param_t = Any;
   bool enable_int8{false};           \
   float input_scale{1.0};            \
   std::vector<float> weight_scale{}; \
-  float output_scale{1.0};
+  float output_scale{1.0};           \
+  int bit_length{8};
 
 /// ----------------------- Functional operators ------------------------------
 struct FeedParam {
-- 
GitLab
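
A note on the scale arithmetic in the new InsertNewNode: the dequant op
records max_range = range^2 / max(abs(weight)), so the recovered
whole_weight_scale = range^2 / max_range / range reduces to
max(abs(weight)) / range, exactly as the comments in the patch state. The
following is a minimal standalone C++ sketch, not code from this patch, that
checks the identity numerically; max_abs_weight is a hypothetical stand-in
for the weight maximum recorded by the training-time quantization pass.

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
      const int bit_length = 8;                       // as read from "bit_length"
      const int range = (1 << (bit_length - 1)) - 1;  // 127 for 8-bit
      const float max_abs_weight = 2.5f;              // hypothetical value

      // The dequant op stores max_range = range^2 / max(abs(weight)).
      float max_range = static_cast<float>(range) * range / max_abs_weight;

      // The fuser recovers range^2 / max_range / range, which reduces
      // algebraically to max(abs(weight)) / range.
      float whole_weight_scale =
          static_cast<float>(range * range) / max_range / range;
      assert(std::fabs(whole_weight_scale - max_abs_weight / range) < 1e-6f);
      printf("weight_scale = %f\n", whole_weight_scale);
      return 0;
    }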
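The weight repacking in InsertNewNode is likewise a plain cast rather than a
fresh quantization: the fake-quant training flow leaves the weights as
integer values stored in a float tensor, so converting the tensor to int8
only changes the storage type. A minimal standalone sketch of that loop
follows, with hypothetical data and assuming the values already lie in
[-127, 127] as the fuser expects.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      // Fake-quantized weights: integer values stored as float.
      std::vector<float> weights = {127.0f, -64.0f, 3.0f, 0.0f};
      std::vector<int8_t> packed(weights.size());
      for (size_t i = 0; i < weights.size(); i++) {
        // Same conversion as the pass: a direct cast, no rounding or scaling.
        packed[i] = static_cast<int8_t>(weights[i]);
      }
      for (int8_t v : packed) {
        printf("%d ", static_cast<int>(v));
      }
      printf("\n");
      return 0;
    }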