From 51e9898d6b9ce01dc679d99b562499d14e8e56e2 Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Fri, 3 Apr 2020 11:40:22 +0800
Subject: [PATCH] Modify quant_dequant_fuse_pass to process quant_dequant_op,
 test=develop (#3341)

---
 lite/api/benchmark.cc                           |  20 +-
 lite/api/cxx_api.cc                             |   5 +-
 lite/core/mir/fusion/quant_dequant_fuse_pass.cc |   8 +-
 lite/core/mir/fusion/quant_dequant_op_fuser.cc  | 260 ++++++------------
 lite/core/mir/fusion/quant_dequant_op_fuser.h   |  12 +-
 lite/core/op_lite.h                             |  26 ++
 6 files changed, 127 insertions(+), 204 deletions(-)

diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc
index d53de7bf2e..0843faf0d6 100644
--- a/lite/api/benchmark.cc
+++ b/lite/api/benchmark.cc
@@ -44,7 +44,10 @@ DEFINE_string(input_shape,
               "set input shapes according to the model, "
               "separated by colon and comma, "
               "such as 1,3,244,244");
-DEFINE_string(input_img_path, "", "the path of input image");
+DEFINE_string(input_img_path,
+              "",
+              "the path of the input image; if input_img_path is not set, "
+              "the model input will be filled with 1.0.");
 DEFINE_int32(warmup, 0, "warmup times");
 DEFINE_int32(repeats, 1, "repeats times");
 DEFINE_int32(power_mode,
@@ -57,16 +60,11 @@
 DEFINE_int32(threads, 1, "threads num");
 DEFINE_string(result_filename,
               "result.txt",
-              "save benchmark "
-              "result to the file");
+              "save the inference time to the file.");
 DEFINE_bool(run_model_optimize,
             false,
             "if set true, apply model_optimize_tool to "
             "model and use optimized model to test. ");
-DEFINE_bool(is_quantized_model,
-            false,
-            "if set true, "
-            "test the performance of the quantized model. ");
 
 namespace paddle {
 namespace lite_api {
@@ -87,10 +85,6 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
   std::vector<Place> vaild_places = {
       Place{TARGET(kARM), PRECISION(kFloat)},
   };
-  if (FLAGS_is_quantized_model) {
-    vaild_places.insert(vaild_places.begin(),
-                        Place{TARGET(kARM), PRECISION(kInt8)});
-  }
   config.set_valid_places(vaild_places);
 
   auto predictor = lite_api::CreatePaddlePredictor(config);
@@ -181,8 +175,8 @@ void Run(const std::vector<int64_t>& input_shape,
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
-  if (FLAGS_model_dir == "" || FLAGS_result_filename == "") {
-    LOG(INFO) << "please run ./benchmark_bin --help to obtain usage.";
+  if (FLAGS_model_dir == "") {
+    LOG(INFO) << "Please run ./benchmark_bin --help to obtain usage.";
     exit(0);
   }

diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 556a9e0af0..6a14b807cf 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -295,6 +295,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
   inner_places.emplace_back(
       TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+  // Analyze whether the model is quantized.
+  // For quantized model, add place(arm, int8) to inner_places
   const std::vector<std::string> quant_dequant_op = {
       "fake_quantize_abs_max",
       "fake_quantize_range_abs_max",
@@ -317,7 +319,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
     }
   }
   if (is_quantized_model) {
-    inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)});
+    inner_places.insert(inner_places.begin(),
+                        Place{TARGET(kARM), PRECISION(kInt8)});
   }
 
   Program program(desc, scope_, inner_places);
diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
index ab81f3d809..80a033c75f 100644
--- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
@@ -44,11 +44,9 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
     fuser(graph.get());
   }
 
-  // delete quant_dequant_node
-  for (auto op_type : {"pool2d", "softmax", "elementwise_add"}) {
-    fusion::DeleteQuantDequantOpFuser fuser(op_type);
-    fuser(graph.get());
-  }
+  // process quant_dequant_node
+  fusion::DeleteQuantDequantOpFuser dqd_fuser;
+  dqd_fuser(graph.get());
 }
 
 }  // namespace mir
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
index 7797864a2e..a3a98b871f 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
@@ -50,7 +50,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
   auto* output_scale_node = matched.at("output_scale_node");
   auto* output_act_node = matched.at("output_act_node");
 
-  // obtain values, save values and relink node
+  // obtain scale, save attrs and relink node
   int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
   int range = ((1 << (bit_length - 1)) - 1);
   auto* scope = quant_node->stmt()->op()->scope();
   auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
                            ->GetMutable<lite::Tensor>();
   float scale_value = scale_tensor->data<float>()[0] / range;
 
+  auto in_act_name = input_act_node->arg()->name;
+  auto out_act_name = output_act_node->arg()->name;
   auto outlinks = output_act_node->outlinks;
   for (auto* quantized_node : outlinks) {
-    auto* op_desc = quantized_node->stmt()->mutable_op_info();
-    op_desc->SetAttr("bit_length", bit_length);
-    op_desc->SetAttr("input_scale", scale_value);
+    // save input scale in quantized op by input argname + index
+    auto op_desc = *quantized_node->stmt()->mutable_op_info();
+    std::string argname;
+    int index;
+    op_desc.GetInputArgname(out_act_name, &argname);
+    op_desc.GetInputIndex(out_act_name, &index);
+    op_desc.SetAttr(argname + std::to_string(index) + "_input_scale",
+                    scale_value);
+    op_desc.SetAttr("input_scale", scale_value);  // save it for now
+    op_desc.SetAttr("bit_length", bit_length);
+    op_desc.UpdateAllInputs(out_act_name, in_act_name);
+    quantized_node->stmt()->ResetOp(op_desc, graph->valid_places());
     IR_NODE_LINK_TO(input_act_node, quantized_node)
   }
@@ -125,19 +136,18 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
   auto* dequant_op = matched.at("dequant_op");
   auto* dequant_op_out = matched.at("dequant_op_out");
 
-  // obtain input_scale and weight_scale
+  // obtain weight_scale from max_range
   auto* scope = quantized_op->stmt()->op()->scope();
   auto& valid_places = quantized_op->stmt()->op()->valid_places();
   int bit_length = quantized_op->stmt()->op_info()->GetAttr<int>("bit_length");
   int range = ((1 << (bit_length - 1)) - 1);
-  float input_scale =
-      quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
   float max_range =
      dequant_op->stmt()->op_info()->GetAttr<float>("max_range");
   float whole_weight_scale =
       static_cast<float>(range * range) / max_range / range;
-  // max_range = range * range / max(abs(weight))
-  // weight_scale = range * range / (range * range / max(abs(weight))) / range
-  //             = max(abs(weight)) / range
+  // As: max_range = range * range / max(abs(weight))
+  // So: whole_weight_scale
+  //     = range * range / (range * range / max(abs(weight))) / range
+  //     = max(abs(weight)) / range
 
   // set op desc
   cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
@@ -153,7 +163,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
     // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
     // be Cout.
     weight_scale_size = quantized_weight_t->dims()[0];
-  } else if (quantized_op_type_ == "mul") {
+  } else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") {
     op_desc.SetInput("X", {quantized_op_input->arg()->name});
     op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
     // Fc weight: Cin * Cout, the weight_scale_size should be Cout.
@@ -163,7 +173,6 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
     weight_scale.push_back(whole_weight_scale);
   }
   op_desc.SetAttr("enable_int8", true);
-  op_desc.SetAttr("input_scale", input_scale);
   op_desc.SetAttr("weight_scale", weight_scale);
 
   // change the weight from the float type to int8 type.
@@ -209,6 +218,7 @@ void ChannelWiseDequantOpFuser::BuildPattern() {
           ->assert_is_op_output(quantized_op_type_)
           ->assert_is_op_input(dequant_op_type, "X")
           ->AsIntermediate();
+  // The scale var_node of input activation is deleted in DeleteQuantOpFuser
   auto* dequant_op_channel_scale = VarNode("dequant_op_channel_scale")
                                        ->assert_is_op_input(dequant_op_type)
                                        ->AsIntermediate();
@@ -237,11 +247,9 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
   auto* dequant_op = matched.at("dequant_op");
   auto* dequant_op_out = matched.at("dequant_op_out");
 
-  // obtain input_scale and weight_scale
+  // obtain input weight_scale from fake_dequant op
   auto* scope = quantized_op->stmt()->op()->scope();
   auto& valid_places = quantized_op->stmt()->op()->valid_places();
-  float input_scale =
-      quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
 
   std::vector<float> weight_scale;
   std::vector<int> quant_bits =
@@ -258,11 +266,15 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
 
   // set op desc
   cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
-  op_desc.SetInput("Input", {quantized_op_input->arg()->name});
-  op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
-
+  if (quantized_op_type_ == "conv2d" ||
+      quantized_op_type_ == "depthwise_conv2d") {
+    op_desc.SetInput("Input", {quantized_op_input->arg()->name});
+    op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
+  } else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") {
+    op_desc.SetInput("X", {quantized_op_input->arg()->name});
+    op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
+  }
   op_desc.SetAttr("enable_int8", true);
-  op_desc.SetAttr("input_scale", input_scale);
   op_desc.SetAttr("weight_scale", weight_scale);
 
   // change the weight from the float type to int8 type.
@@ -297,167 +309,65 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
 void DeleteQuantDequantOpFuser::BuildPattern() {
   std::string quant_dequant_op_type =
       "fake_quantize_dequantize_moving_average_abs_max";
-  if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
-    auto* input_scale_node =
-        VarNode("input_scale_node")
-            ->assert_is_op_input(quant_dequant_op_type, "InScale");
-    auto* input_act_node = VarNode("input_act_node")
-                               ->assert_is_op_input(quant_dequant_op_type, "X");
-    auto* quant_dequant_node =
-        OpNode("quant_dequant_node", quant_dequant_op_type)
-            ->assert_is_op(quant_dequant_op_type);
-    auto* output_scale_node =
-        VarNode("output_scale_node")
-            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
-    auto* output_act_node =
-        VarNode("output_act_node")
-            ->assert_is_op_output(quant_dequant_op_type, "Out");
-    auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
-                               ->assert_is_op(quantized_op_type_);
-
-    quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
-    output_scale_node->LinksFrom({quant_dequant_node});
-    output_act_node->LinksFrom({quant_dequant_node});
-    quantized_node->LinksFrom({output_act_node});
-  } else if (quantized_op_type_ == "elementwise_add") {
-    auto* input_scale_left_node =
-        VarNode("input_scale_left_node")
-            ->assert_is_op_input(quant_dequant_op_type, "InScale");
-    auto* input_act_left_node =
-        VarNode("input_act_left_node")
-            ->assert_is_op_input(quant_dequant_op_type, "X");
-    auto* quant_dequant_left_node =
-        OpNode("quant_dequant_left_node", quant_dequant_op_type)
-            ->assert_is_op(quant_dequant_op_type);
-    auto* output_scale_left_node =
-        VarNode("output_scale_left_node")
-            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
-    auto* output_act_left_node =
-        VarNode("output_act_left_node")
-            ->assert_is_op_output(quant_dequant_op_type, "Out")
-            ->assert_is_op_input(quantized_op_type_, "X");
-    quant_dequant_left_node->LinksFrom(
-        {input_scale_left_node, input_act_left_node});
-    output_scale_left_node->LinksFrom({quant_dequant_left_node});
-    output_act_left_node->LinksFrom({quant_dequant_left_node});
-
-    auto* input_scale_right_node =
-        VarNode("input_scale_right_node")
-            ->assert_is_op_input(quant_dequant_op_type, "InScale");
-    auto* input_act_right_node =
-        VarNode("input_act_right_node")
-            ->assert_is_op_input(quant_dequant_op_type, "X");
-    auto* quant_dequant_right_node =
-        OpNode("quant_dequant_right_node", quant_dequant_op_type)
-            ->assert_is_op(quant_dequant_op_type);
-    auto* output_scale_right_node =
-        VarNode("output_scale_right_node")
-            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
-    auto* output_act_right_node =
-        VarNode("output_act_right_node")
-            ->assert_is_op_output(quant_dequant_op_type, "Out")
-            ->assert_is_op_input(quantized_op_type_, "Y");
-    quant_dequant_right_node->LinksFrom(
-        {input_scale_right_node, input_act_right_node});
-    output_scale_right_node->LinksFrom({quant_dequant_right_node});
-    output_act_right_node->LinksFrom({quant_dequant_right_node});
-
-    auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
-                               ->assert_is_op(quantized_op_type_);
-    quantized_node->LinksFrom({output_act_left_node, output_act_right_node});
-  } else {
-    LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
-  }
-  VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:"
-          << quantized_op_type_;
+  auto* input_scale_node =
+      VarNode("input_scale_node")
+          ->assert_is_op_input(quant_dequant_op_type, "InScale");
+  auto* input_act_node =
+      VarNode("input_act_node")->assert_is_op_input(quant_dequant_op_type, "X");
+  auto* quant_dequant_node = OpNode("quant_dequant_node", quant_dequant_op_type)
+                                 ->assert_is_op(quant_dequant_op_type);
+  auto* output_scale_node =
+      VarNode("output_scale_node")
+          ->assert_is_op_output(quant_dequant_op_type, "OutScale");
+  auto* output_act_node =
+      VarNode("output_act_node")
+          ->assert_is_op_output(quant_dequant_op_type, "Out");
+
+  quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
+  output_scale_node->LinksFrom({quant_dequant_node});
+  output_act_node->LinksFrom({quant_dequant_node});
 }
 
 void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
                                               const key2nodes_t& matched) {
-  if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
-    auto* input_scale_node = matched.at("input_scale_node");
-    auto* input_act_node = matched.at("input_act_node");
-    auto* quant_dequant_node = matched.at("quant_dequant_node");
-    auto* output_scale_node = matched.at("output_scale_node");
-    auto* output_act_node = matched.at("output_act_node");
-    auto* quantized_node = matched.at("quantized_node");
-
-    // obtain values, save values and relink node
-    int bit_length =
-        quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
-    int range = ((1 << (bit_length - 1)) - 1);
-    auto* scope = quant_dequant_node->stmt()->op()->scope();
-    auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
-                             ->GetMutable<lite::Tensor>();
-    float scale_value = scale_tensor->data<float>()[0] / range;
-
-    auto* op_desc = quantized_node->stmt()->mutable_op_info();
-    op_desc->SetAttr("bit_length", bit_length);
-    op_desc->SetAttr("input_scale", scale_value);
-    op_desc->SetInput("X", {input_act_node->arg()->name});
-    IR_NODE_LINK_TO(input_act_node, quantized_node)
-    auto update_op_desc = *quantized_node->stmt()->mutable_op_info();
-    quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places());
-
-    // delete nodes and edges
-    std::unordered_set<const Node*> nodes2rm = {input_scale_node,
-                                                quant_dequant_node,
-                                                output_scale_node,
-                                                output_act_node};
-    GraphSafeRemoveNodes(graph, nodes2rm);
-  } else if (quantized_op_type_ == "elementwise_add") {
-    auto* input_scale_left_node = matched.at("input_scale_left_node");
-    auto* input_act_left_node = matched.at("input_act_left_node");
-    auto* quant_dequant_left_node = matched.at("quant_dequant_left_node");
-    auto* output_scale_left_node = matched.at("output_scale_left_node");
-    auto* output_act_left_node = matched.at("output_act_left_node");
-
-    auto* input_scale_right_node = matched.at("input_scale_right_node");
-    auto* input_act_right_node = matched.at("input_act_right_node");
-    auto* quant_dequant_right_node = matched.at("quant_dequant_right_node");
-    auto* output_scale_right_node = matched.at("output_scale_right_node");
-    auto* output_act_right_node = matched.at("output_act_right_node");
-
-    auto* quantized_node = matched.at("quantized_node");
-
-    // obtain values, save values and relink node
-    int bit_length =
-        quant_dequant_left_node->stmt()->op_info()->GetAttr<int>("bit_length");
-    int range = ((1 << (bit_length - 1)) - 1);
-    auto* scope = quant_dequant_left_node->stmt()->op()->scope();
-    auto* left_scale_tensor =
-        scope->FindVar(output_scale_left_node->arg()->name)
-            ->GetMutable<lite::Tensor>();
-    float left_scale_value = left_scale_tensor->data<float>()[0] / range;
-    auto* right_scale_tensor =
-        scope->FindVar(output_scale_right_node->arg()->name)
-            ->GetMutable<lite::Tensor>();
-    float right_scale_value = right_scale_tensor->data<float>()[0] / range;
-
-    auto* op_desc = quantized_node->stmt()->mutable_op_info();
-    op_desc->SetAttr("bit_length", bit_length);
-    op_desc->SetAttr("x_input_scale", left_scale_value);
-    op_desc->SetAttr("y_input_scale", right_scale_value);
-    op_desc->SetInput("X", {input_act_left_node->arg()->name});
-    op_desc->SetInput("Y", {input_act_right_node->arg()->name});
-    IR_NODE_LINK_TO(input_act_left_node, quantized_node)
-    IR_NODE_LINK_TO(input_act_right_node, quantized_node)
-    auto update_op_desc = *quantized_node->stmt()->mutable_op_info();
-    quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places());
-
-    // delete nodes and edges
-    std::unordered_set<const Node*> nodes2rm = {input_scale_left_node,
-                                                quant_dequant_left_node,
-                                                output_scale_left_node,
-                                                output_act_left_node,
-                                                input_scale_right_node,
-                                                quant_dequant_right_node,
-                                                output_scale_right_node,
-                                                output_act_right_node};
-    GraphSafeRemoveNodes(graph, nodes2rm);
-  } else {
-    LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
+  auto* input_scale_node = matched.at("input_scale_node");
+  auto* input_act_node = matched.at("input_act_node");
+  auto* quant_dequant_node = matched.at("quant_dequant_node");
+  auto* output_scale_node = matched.at("output_scale_node");
+  auto* output_act_node = matched.at("output_act_node");
+  auto input_act_name = input_act_node->arg()->name;
+  auto output_act_name = output_act_node->arg()->name;
+
+  // Get scale value from scale var node
+  int bit_length =
+      quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
+  int range = ((1 << (bit_length - 1)) - 1);
+  auto* scope = quant_dequant_node->stmt()->op()->scope();
+  auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  float scale_value = scale_tensor->data<float>()[0] / range;
+
+  auto quantized_nodes = output_act_node->outlinks;
+  for (auto* quantized_node : quantized_nodes) {
+    // Save quantization info in op_info attr
+    auto op_info = *quantized_node->stmt()->op_info();
+    std::string argname;
+    int index;
+    op_info.GetInputArgname(output_act_name, &argname);
+    op_info.GetInputIndex(output_act_name, &index);
+    op_info.SetAttr(argname + std::to_string(index) + "_input_scale",
+                    scale_value);
+    op_info.SetAttr("input_scale", scale_value);  // Save it for now
+    op_info.SetAttr("bit_length", bit_length);
+
+    op_info.UpdateAllInputs(output_act_name, input_act_name);
+    quantized_node->stmt()->ResetOp(op_info, graph->valid_places());
+    IR_NODE_LINK_TO(input_act_node, quantized_node);
   }
+
+  // delete nodes and edges
+  std::unordered_set<const Node*> nodes2rm = {
+      input_scale_node, quant_dequant_node, output_scale_node, output_act_node};
+  GraphSafeRemoveNodes(graph, nodes2rm);
 }
 
 cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h
index bef9f4d957..ac3ac112b3 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.h
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h
@@ -87,24 +87,16 @@ class ChannelWiseDequantOpFuser : public FuseBase {
 };
 
 /* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
- * pooled/elementwise_add" can be deteted by this fuser. The fuser
- * extract the input_scale form fake_quant_dequant_op and save into
- * the quantized_op. Besides, the fuser delete fake_quant_dequant_op in
- * the graph.
+ * quantized_op" can be detected by this fuser. The fuser modifies the input
+ * scale for the quantized_op and deletes the fake_quant_dequant_op.
 */
-
 class DeleteQuantDequantOpFuser : public FuseBase {
  public:
-  explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type)
-      : quantized_op_type_(quantized_op_type) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-
- private:
-  std::string quantized_op_type_{};
 };
 
 }  // namespace fusion
diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h
index 4c6c66be7e..1cdc33825c 100644
--- a/lite/core/op_lite.h
+++ b/lite/core/op_lite.h
@@ -225,6 +225,32 @@ class OpInfo : public cpp::OpDesc {
     return false;
   }
 
+  // For the given input variable name, find its index within the
+  // variable list of the corresponding input argname
+  bool GetInputIndex(const std::string &value_name, int *out) const {
+    for (auto &item : inputs_) {
+      auto it = std::find(item.second.begin(), item.second.end(), value_name);
+      if (it != item.second.end()) {
+        *out = it - item.second.begin();
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // For the given output variable name, find its index within the
+  // variable list of the corresponding output argname
+  bool GetOutputIndex(const std::string &value_name, int *out) const {
+    for (auto &item : outputs_) {
+      auto it = std::find(item.second.begin(), item.second.end(), value_name);
+      if (it != item.second.end()) {
+        *out = it - item.second.begin();
+        return true;
+      }
+    }
+    return false;
+  }
+
   void UpdateAllInputs(const std::string &from, const std::string &to) {
     for (auto &item : inputs_) {
      for (auto &var : item.second) {
-- 
GitLab
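
Note on the scale arithmetic in DequantOpFuser::InsertNewNode: the comment block
in that hunk derives whole_weight_scale from max_range. Below is a standalone
numeric check of that derivation; it is an illustrative sketch, not part of the
patch, and bit_length = 8 and max(abs(weight)) = 0.5 are made-up example values.

#include <cassert>
#include <cmath>

int main() {
  int bit_length = 8;
  int range = (1 << (bit_length - 1)) - 1;  // 127 for int8
  float max_abs_weight = 0.5f;              // example value
  // The fake_dequantize op stores: max_range = range * range / max(abs(weight))
  float max_range = static_cast<float>(range) * range / max_abs_weight;
  // The fuser recovers: whole_weight_scale = max(abs(weight)) / range
  float whole_weight_scale =
      static_cast<float>(range * range) / max_range / range;
  assert(std::fabs(whole_weight_scale - max_abs_weight / range) < 1e-9f);
  return 0;
}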
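
Note on the new attribute key written by the fusers: the activation scale is now
stored under argname + index + "_input_scale", where the index comes from the
new OpInfo::GetInputIndex helper. Below is a minimal standalone sketch of that
lookup, with OpInfo's inputs_ map simplified to a plain std::map; the variable
names are illustrative, not taken from the patch.

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // An op whose "X" argument reads two variables and "Y" reads one.
  std::map<std::string, std::vector<std::string>> inputs = {
      {"X", {"act_out_1", "act_out_2"}}, {"Y", {"weight_0"}}};
  std::string value_name = "act_out_2";
  for (const auto& item : inputs) {
    auto it = std::find(item.second.begin(), item.second.end(), value_name);
    if (it != item.second.end()) {
      std::string argname = item.first;  // "X"
      int index = static_cast<int>(it - item.second.begin());  // 1
      // The fuser would store the scale under the attr key "X1_input_scale".
      std::cout << argname + std::to_string(index) + "_input_scale\n";
    }
  }
  return 0;
}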
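
Note on the cxx_api.cc change: the kInt8 place is now inserted at the front of
inner_places rather than appended, presumably because places earlier in the
list take precedence during kernel selection, so int8 kernels are preferred
over float ones when the model is detected as quantized.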