From ee94db869a06307699903468d1c0b67e3a349520 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Mon, 29 Jun 2020 22:47:31 +0800 Subject: [PATCH] [Core ] Update quantization: save scales in op attrs by (#3816) * Update quantization, scale save in op attrs by , test=develop Co-authored-by: hong19860320 <9973393+hong19860320@users.noreply.github.com> --- lite/core/CMakeLists.txt | 2 +- lite/core/mir/fusion/conv_bn_fuser.cc | 7 +- lite/core/mir/fusion/fc_fuser.cc | 33 ++++- .../core/mir/fusion/quant_dequant_op_fuser.cc | 35 ++---- .../quantized_op_attributes_inference_pass.cc | 47 ++++--- lite/core/mir/static_kernel_pick_pass.cc | 9 +- lite/core/mir/subgraph/subgraph_detector.cc | 10 +- lite/core/mir/type_precision_cast_pass.cc | 8 +- lite/core/op_lite.cc | 118 ++++++++++++++++++ lite/core/op_lite.h | 71 +++++------ lite/kernels/apu/bridges/conv_op.cc | 16 ++- lite/kernels/apu/bridges/fc_op.cc | 15 ++- lite/kernels/apu/bridges/pool_op.cc | 10 +- lite/kernels/apu/bridges/softmax_op.cc | 10 +- lite/operators/conv_op.h | 20 +-- lite/operators/fc_op.cc | 21 ++-- 16 files changed, 296 insertions(+), 136 deletions(-) diff --git a/lite/core/CMakeLists.txt b/lite/core/CMakeLists.txt index 56a5c9b8f7..af2bfbe86a 100644 --- a/lite/core/CMakeLists.txt +++ b/lite/core/CMakeLists.txt @@ -121,7 +121,7 @@ lite_cc_library(kernel SRCS kernel.cc PROFILE_DEPS lite_profiler ) lite_cc_library(op SRCS op_lite.cc DEPS scope op_registry target_wrapper kernel - cpp_op_desc tensor + cpp_op_desc tensor utils ) add_dependencies(kernel kernel_list_h) diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc index 69be8dab0a..abfb47f305 100644 --- a/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/lite/core/mir/fusion/conv_bn_fuser.cc @@ -156,12 +156,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { // little difference for int8 /////////////////////////////////////////////////////////////////////////////// if (enable_int8) { - PADDLE_ENFORCE(conv_op_desc->HasAttr("weight_scale"), + std::string weight_name = conv_op_desc->Input("Filter").front(); + PADDLE_ENFORCE(conv_op_desc->HasInputScale(weight_name), "INT8 mode: Conv should has weight_scale attr"); auto conv_weight_d = conv_weight_t->mutable_data(); // compute new conv_weight for int8 auto weight_scale = - conv_op_desc->GetAttr>("weight_scale"); + conv_op_desc->GetInputScale>(weight_name); if (conv_type_ == "conv2d_transpose" && !depthwise) { int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; @@ -188,7 +189,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { } } } - conv_op_desc->SetAttr("weight_scale", weight_scale); + conv_op_desc->SetInputScale(weight_name, weight_scale); } else if (is_weight_quantization) { std::string scale_name = conv_weight_name + "_quant_scale"; if (conv_op_desc->HasAttr(scale_name)) { diff --git a/lite/core/mir/fusion/fc_fuser.cc b/lite/core/mir/fusion/fc_fuser.cc index 3c99131083..e0254aca97 100644 --- a/lite/core/mir/fusion/fc_fuser.cc +++ b/lite/core/mir/fusion/fc_fuser.cc @@ -71,7 +71,27 @@ void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { } cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { - cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info(); + auto op_desc = *matched.at("mul")->stmt()->op_info(); + + // Get the input scale from mul + float x_scale{}; + std::vector y_scale_vct; + auto y_var_node = matched.at("W")->arg(); + auto input_x_name = op_desc.Input("X").front(); + auto input_y_name = op_desc.Input("Y").front(); + bool is_quantized_op = op_desc.HasInputScale(input_x_name) && + op_desc.HasInputScale(input_y_name); + if (is_quantized_op) { + x_scale = op_desc.GetInputScale(input_x_name); + if (y_var_node->is_weight) { // the scale of y is a vector + y_scale_vct = + op_desc.GetInputScale>(op_desc.Input("Y").front()); + } else { + y_scale_vct.push_back( // the scale of y is scalar + op_desc.GetInputScale(op_desc.Input("Y").front())); + } + } + op_desc.mutable_inputs()->clear(); op_desc.mutable_outputs()->clear(); op_desc.SetType("fc"); @@ -85,6 +105,17 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { if (with_relu_) { op_desc.SetAttr("activation_type", std::string{"relu"}); } + + // Set the input scale into fc + if (is_quantized_op) { + op_desc.SetInputScale(matched.at("x")->arg()->name, x_scale); + if (y_var_node->is_weight) { + op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct); + } else { + op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct.front()); + } + } + return op_desc; } diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index f6d03cc23d..b71fde125c 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -64,13 +64,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph, for (auto* quantized_node : outlinks) { // save input scale in quantized op by input argname + index auto op_desc = *quantized_node->stmt()->mutable_op_info(); - std::string argname; - int index; - op_desc.GetInputArgname(out_act_name, &argname); - op_desc.GetInputIndex(out_act_name, &index); - op_desc.SetAttr(argname + std::to_string(index) + "_input_scale", - scale_value); - op_desc.SetAttr("input_scale", scale_value); // save it for now + op_desc.SetInputScale(out_act_name, scale_value); op_desc.SetAttr("bit_length", bit_length); op_desc.UpdateAllInputs(out_act_name, in_act_name); quantized_node->stmt()->ResetOp(op_desc, graph->valid_places()); @@ -135,6 +129,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, auto* quantized_op = matched.at("quantized_op"); auto* dequant_op = matched.at("dequant_op"); auto* dequant_op_out = matched.at("dequant_op_out"); + auto weight_name = quantized_op_weight->arg()->name; // obtain weight_scale from max_range auto* scope = quantized_op->stmt()->op()->scope(); @@ -150,7 +145,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, // = max(abs(weight)) / range // set op desc - cpp::OpDesc op_desc = *quantized_op->stmt()->op_info(); + auto op_desc = *quantized_op->stmt()->op_info(); auto quantized_weight_var_name = quantized_op_weight->arg()->name; auto quantized_weight_t = scope->FindVar(quantized_weight_var_name)->GetMutable(); @@ -173,7 +168,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, weight_scale.push_back(whole_weight_scale); } op_desc.SetAttr("enable_int8", true); - op_desc.SetAttr("weight_scale", weight_scale); + op_desc.SetInputScale(weight_name, weight_scale); // change the weight from the float type to int8 type. Tensor temp_tensor; @@ -246,6 +241,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, auto* dequant_op_channel_scale = matched.at("dequant_op_channel_scale"); auto* dequant_op = matched.at("dequant_op"); auto* dequant_op_out = matched.at("dequant_op_out"); + auto weight_name = quantized_op_weight->arg()->name; // obtain input weight_scale from fake_dequant op auto* scope = quantized_op->stmt()->op()->scope(); @@ -265,7 +261,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, } // set op desc - cpp::OpDesc op_desc = *quantized_op->stmt()->op_info(); + auto op_desc = *quantized_op->stmt()->op_info(); if (quantized_op_type_ == "conv2d" || quantized_op_type_ == "depthwise_conv2d") { op_desc.SetInput("Input", {quantized_op_input->arg()->name}); @@ -275,7 +271,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); } op_desc.SetAttr("enable_int8", true); - op_desc.SetAttr("weight_scale", weight_scale); + op_desc.SetInputScale(weight_name, weight_scale); // change the weight from the float type to int8 type. auto quantized_weight_var_name = quantized_op_weight->arg()->name; @@ -352,22 +348,7 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, // Save quantization info in op_info attr auto op_info = *quantized_node->stmt()->op_info(); op_info.SetAttr("bit_length", bit_length); - - std::string argname; - int index; - op_info.GetInputArgname(output_act_name, &argname); - op_info.GetInputIndex(output_act_name, &index); - op_info.SetAttr(argname + std::to_string(index) + "_input_scale", - scale_value); - std::string op_type = op_info.Type(); - // Analyse the weight scale or input scale. - if (((op_type == "conv2d" || op_type == "depthwise_conv2d") && - argname == "Input") || - ((op_type == "mul" || op_type == "matmul") && argname == "Y")) { - op_info.SetAttr("weight_scale", scale_value); - } else { - op_info.SetAttr("input_scale", scale_value); - } + op_info.SetInputScale(output_act_name, scale_value); op_info.UpdateAllInputs(output_act_name, input_act_name); quantized_node->stmt()->ResetOp(op_info, graph->valid_places()); diff --git a/lite/core/mir/quantized_op_attributes_inference_pass.cc b/lite/core/mir/quantized_op_attributes_inference_pass.cc index 66b37446a4..d1834a70e3 100644 --- a/lite/core/mir/quantized_op_attributes_inference_pass.cc +++ b/lite/core/mir/quantized_op_attributes_inference_pass.cc @@ -37,17 +37,30 @@ void QuantizedOpAttributesInferencePass::Apply( auto& inst = op_node->AsStmt(); auto op_info = inst.op_info(); auto op_type = op_info->Type(); - if (!op_info->HasAttr("input_scale")) continue; - bool found = false; - float output_scale; + + // Check only if all of the inputs of the op have scale value + bool has_input_scale = true; + for (auto in_var_node : op_node->inlinks) { + CHECK(in_var_node->IsArg()); + auto in_var_node_name = in_var_node->arg()->name; + has_input_scale &= op_info->HasInputScale(in_var_node_name); + } + if (!has_input_scale) continue; + + // Infer the output scale according to its out_threshold or the input scale + // of its adjacent ops + bool is_quantized = true; for (auto out_var_node : op_node->outlinks) { CHECK(out_var_node->IsArg()); + bool found = false; + float output_scale; + auto out_var_node_name = out_var_node->arg()->name; for (auto out_op_node : out_var_node->outlinks) { CHECK(out_op_node->IsStmt()); auto& out_inst = out_op_node->AsStmt(); auto out_op_info = out_inst.op_info(); - if (!out_op_info->HasAttr("input_scale")) continue; - auto input_scale = out_op_info->GetAttr("input_scale"); + if (!out_op_info->HasInputScale(out_var_node_name)) continue; + auto input_scale = out_op_info->GetInputScale(out_var_node_name); if (!found) { found = true; output_scale = input_scale; @@ -55,16 +68,22 @@ void QuantizedOpAttributesInferencePass::Apply( CHECK_EQ(output_scale, input_scale); } } + if (found) { + inst.mutable_op_info()->SetOutputScale(out_var_node_name, output_scale); + } else if (op_info->HasAttr("out_threshold")) { + // Only consider one output, there are only one out_threshold + int bit_length = op_info->GetAttr("bit_length"); + int range = (1 << (bit_length - 1)) - 1; + output_scale = op_info->GetAttr("out_threshold"); + inst.mutable_op_info()->SetOutputScale(out_var_node_name, + output_scale / range); + } else { + is_quantized = false; + } } - if (found) { - inst.mutable_op_info()->SetAttr("output_scale", output_scale); - } else if (op_info->HasAttr("output_scale")) { - int bit_length = op_info->GetAttr("bit_length"); - int range = (1 << (bit_length - 1)) - 1; - output_scale = op_info->GetAttr("output_scale"); - inst.mutable_op_info()->SetAttr("output_scale", output_scale / range); - } - if (op_info->HasAttr("output_scale")) { + + // Fix the missing of the attribute 'enable_int8'. + if (is_quantized) { inst.mutable_op_info()->SetAttr("enable_int8", true); } } diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc index 1de0d1a265..849c9013ff 100644 --- a/lite/core/mir/static_kernel_pick_pass.cc +++ b/lite/core/mir/static_kernel_pick_pass.cc @@ -110,15 +110,16 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { if (out_type_int8) { auto out_node = node.outlinks.front(); CHECK(out_node->IsArg()); + auto out_node_name = out_node->arg()->name; auto one_adj_op_node = out_node->outlinks.front(); CHECK(one_adj_op_node->IsStmt()); auto& one_adj_instruct = one_adj_op_node->AsStmt(); CHECK(one_adj_instruct.op_info()->HasAttr("enable_int8")); - CHECK(one_adj_instruct.op_info()->HasAttr("input_scale")); + CHECK(one_adj_instruct.op_info()->HasInputScale(out_node_name)); - instruct.mutable_op_info()->SetAttr( - "output_scale", - one_adj_instruct.op_info()->GetAttr("input_scale")); + instruct.mutable_op_info()->SetOutputScale( + out_node_name, + one_adj_instruct.op_info()->GetInputScale(out_node_name)); auto update_desc = *instruct.mutable_op_info(); instruct.ResetOp(update_desc, graph->valid_places()); diff --git a/lite/core/mir/subgraph/subgraph_detector.cc b/lite/core/mir/subgraph/subgraph_detector.cc index 31a38280ff..9d42cca204 100644 --- a/lite/core/mir/subgraph/subgraph_detector.cc +++ b/lite/core/mir/subgraph/subgraph_detector.cc @@ -457,21 +457,23 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, std::vector input_data_scales; std::vector output_data_scales; for (auto &var_node : input_var_nodes) { + auto var_node_name = var_node->arg()->name; auto any_op_node = var_node->outlinks.front(); CHECK(any_op_node->IsStmt()); auto &any_inst = any_op_node->AsStmt(); - if (any_inst.op_info()->HasAttr("input_scale")) { + if (any_inst.op_info()->HasInputScale(var_node_name)) { input_data_scales.push_back( - any_inst.op_info()->GetAttr("input_scale")); + any_inst.op_info()->GetInputScale(var_node_name)); } } for (auto &var_node : output_var_nodes) { + auto var_node_name = var_node->arg()->name; auto any_op_node = var_node->inlinks.front(); CHECK(any_op_node->IsStmt()); auto &any_inst = any_op_node->AsStmt(); - if (any_inst.op_info()->HasAttr("output_scale")) { + if (any_inst.op_info()->HasOutputScale(var_node_name)) { output_data_scales.push_back( - any_inst.op_info()->GetAttr("output_scale")); + any_inst.op_info()->GetOutputScale(var_node_name)); } } if (input_data_scales.size() > 0) { diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc index 2564887756..0e90fcb96f 100644 --- a/lite/core/mir/type_precision_cast_pass.cc +++ b/lite/core/mir/type_precision_cast_pass.cc @@ -107,8 +107,8 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) { if (op_type == "subgraph") { found = InferScaleFromSubgraph(var_name, op_info, scale, false); } else { - if (op_info->HasAttr("input_scale")) { - *scale = op_info->GetAttr("input_scale"); + if (op_info->HasInputScale(var_name)) { + *scale = op_info->GetInputScale(var_name); found = true; } else { // Obtain the output_scale from one of its previous Ops @@ -120,8 +120,8 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) { if (prev_op_type == "subgraph") { found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true); } else { - if (prev_op_info->HasAttr("output_scale")) { - *scale = prev_op_info->GetAttr("output_scale"); + if (prev_op_info->HasOutputScale(var_name)) { + *scale = prev_op_info->GetOutputScale(var_name); found = true; } } diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc index 537636065d..199bd69bfd 100644 --- a/lite/core/op_lite.cc +++ b/lite/core/op_lite.cc @@ -22,6 +22,14 @@ namespace paddle { namespace lite { +std::string int2string(int index) { + const int BUFFER_LENGTH = 30; + char buffer[BUFFER_LENGTH]; + int num = snprintf(buffer, sizeof(buffer), "%d", index); + CHECK(num > 0 && num < sizeof(buffer)); + return std::string(buffer); +} + bool OpLite::InferShape() { // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_ // InferShapeByMemoryInternal will be applied. @@ -186,5 +194,115 @@ void OpLite::AttachOutput(const cpp::OpDesc &op_desc, } } +bool OpInfo::GetInputArgname(const std::string &value_name, + std::string *out) const { + for (auto &item : inputs_) { + auto it = std::find(item.second.begin(), item.second.end(), value_name); + if (it != item.second.end()) { + *out = item.first; + return true; + } + } + return false; +} + +bool OpInfo::GetOutputArgname(const std::string &value_name, + std::string *out) const { + for (auto &item : outputs_) { + auto it = std::find(item.second.begin(), item.second.end(), value_name); + if (it != item.second.end()) { + *out = item.first; + return true; + } + } + return false; +} + +bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const { + for (auto &item : inputs_) { + auto it = std::find(item.second.begin(), item.second.end(), input_name); + if (it != item.second.end()) { + *out = it - item.second.begin(); + return true; + } + } + return false; +} + +bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const { + for (auto &item : outputs_) { + auto it = std::find(item.second.begin(), item.second.end(), output_name); + if (it != item.second.end()) { + *out = it - item.second.begin(); + return true; + } + } + return false; +} + +bool OpInfo::HasInputScale(const std::string &input_name) const { + std::string argname; + int index; + if (GetInputArgname(input_name, &argname) && + GetInputIndex(input_name, &index)) { + return HasAttr(argname + int2string(index) + "_scale"); + } else { + return false; + } +} + +bool OpInfo::HasOutputScale(const std::string &output_name) const { + std::string argname; + int index; + if (GetOutputArgname(output_name, &argname) && + GetOutputIndex(output_name, &index)) { + return HasAttr(argname + int2string(index) + "_scale"); + } else { + return false; + } +} + +template <> +void OpInfo::SetInputScale(const std::string &input_name, + const float &scale_value) { + std::string argname; + int index; + CHECK(GetInputArgname(input_name, &argname)); + CHECK(GetInputIndex(input_name, &index)); + SetAttr(argname + int2string(index) + "_scale", scale_value); +} + +template <> +void OpInfo::SetInputScale(const std::string &input_name, + const std::vector &scale_value) { + std::string argname; + int index; + CHECK(GetInputArgname(input_name, &argname)); + CHECK(GetInputIndex(input_name, &index)); + SetAttr>(argname + int2string(index) + "_scale", + scale_value); +} + +template <> +void OpInfo::SetOutputScale(const std::string &output_name, + const float &scale_value) { + std::string argname; + int index; + CHECK(GetOutputArgname(output_name, &argname)); + CHECK(GetOutputIndex(output_name, &index)); + SetAttr(argname + int2string(index) + "_scale", scale_value); +} + +template <> +void OpInfo::SetOutputScale(const std::string &output_name, + const std::vector &scale_value) { + std::string argname; + int index; + CHECK(GetOutputArgname(output_name, &argname)); + CHECK(GetOutputIndex(output_name, &index)); + SetAttr>(argname + int2string(index) + "_scale", + scale_value); +} + } // namespace lite } // namespace paddle diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h index 301065d5b6..8e0a302eb3 100644 --- a/lite/core/op_lite.h +++ b/lite/core/op_lite.h @@ -30,6 +30,8 @@ namespace paddle { namespace lite { +std::string int2string(int index); + // For registry factory. struct Registry { void Touch() {} @@ -229,51 +231,36 @@ class OpInfo : public cpp::OpDesc { return OutputArgumentNames(); } - bool GetInputArgname(const std::string &value_name, std::string *out) const { - for (auto &item : inputs_) { - auto it = std::find(item.second.begin(), item.second.end(), value_name); - if (it != item.second.end()) { - *out = item.first; - return true; - } - } - return false; - } - bool GetOutputArgname(const std::string &value_name, std::string *out) const { - for (auto &item : outputs_) { - auto it = std::find(item.second.begin(), item.second.end(), value_name); - if (it != item.second.end()) { - *out = item.first; - return true; - } - } - return false; - } + bool GetInputArgname(const std::string &value_name, std::string *out) const; + bool GetOutputArgname(const std::string &value_name, std::string *out) const; - // For the input variable name, find the index of the corresponding - // input argname - bool GetInputIndex(const std::string &value_name, int *out) const { - for (auto &item : inputs_) { - auto it = std::find(item.second.begin(), item.second.end(), value_name); - if (it != item.second.end()) { - *out = it - item.second.begin(); - return true; - } - } - return false; + bool GetInputIndex(const std::string &input_name, int *out) const; + bool GetOutputIndex(const std::string &output_name, int *out) const; + + bool HasInputScale(const std::string &input_name) const; + bool HasOutputScale(const std::string &output_name) const; + + template + void SetInputScale(const std::string &input_name, const T &scale_value); + template + void SetOutputScale(const std::string &output_name, const T &scale_value); + + template + T GetInputScale(const std::string &input_name) const { + std::string argname; + int index; + CHECK(GetInputArgname(input_name, &argname)); + CHECK(GetInputIndex(input_name, &index)); + return GetAttr(argname + int2string(index) + "_scale"); } - // For the output variable name, find the index of the corresponding - // output argname - bool GetOutputIndex(const std::string &value_name, int *out) const { - for (auto &item : outputs_) { - auto it = std::find(item.second.begin(), item.second.end(), value_name); - if (it != item.second.end()) { - *out = it - item.second.begin(); - return true; - } - } - return false; + template + T GetOutputScale(const std::string &output_name) const { + std::string argname; + int index; + CHECK(GetOutputArgname(output_name, &argname)); + CHECK(GetOutputIndex(output_name, &index)); + return GetAttr(argname + int2string(index) + "_scale"); } void UpdateAllInputs(const std::string &from, const std::string &to) { diff --git a/lite/kernels/apu/bridges/conv_op.cc b/lite/kernels/apu/bridges/conv_op.cc index ca6e0ff2ac..a960269c94 100644 --- a/lite/kernels/apu/bridges/conv_op.cc +++ b/lite/kernels/apu/bridges/conv_op.cc @@ -99,12 +99,16 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { std::vector weight_scale; if (op_info->HasAttr("enable_int8")) { if (op_info->GetAttr("enable_int8")) { - if (op_info->HasAttr("input_scale")) - input_scale = op_info->GetAttr("input_scale"); - if (op_info->HasAttr("weight_scale")) - weight_scale = op_info->GetAttr>("weight_scale"); - if (op_info->HasAttr("output_scale")) - output_scale = op_info->GetAttr("output_scale"); + auto input_name = op_info->Input("Input").front(); + auto filter_name = op_info->Input("Filter").front(); + auto output_name = op_info->Output("Output").front(); + if (op_info->HasInputScale(input_name)) + input_scale = op_info->GetInputScale(input_name); + if (op_info->HasInputScale(filter_name)) + weight_scale = op_info->GetInputScale>(filter_name); + if (op_info->HasOutputScale(output_name)) { + output_scale = op_info->GetOutputScale(output_name); + } VLOG(3) << "has output scale:" << output_scale; } else { return FAILED; diff --git a/lite/kernels/apu/bridges/fc_op.cc b/lite/kernels/apu/bridges/fc_op.cc index a00a35f9a0..48a439482c 100644 --- a/lite/kernels/apu/bridges/fc_op.cc +++ b/lite/kernels/apu/bridges/fc_op.cc @@ -57,12 +57,15 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { std::vector w_scale; if (op_info->HasAttr("enable_int8")) { if (op_info->GetAttr("enable_int8")) { - if (op_info->HasAttr("input_scale")) - input_scale = op_info->GetAttr("input_scale"); - if (op_info->HasAttr("weight_scale")) - w_scale = op_info->GetAttr>("weight_scale"); - if (op_info->HasAttr("output_scale")) - out_scale = op_info->GetAttr("output_scale"); + auto input_name = op_info->Input("Input").front(); + auto weight_name = op_info->Input("W").front(); + auto out_name = op_info->Output("Out").front(); + if (op_info->HasInputScale(input_name)) + input_scale = op_info->GetInputScale(input_name); + if (op_info->HasInputScale(weight_name)) + w_scale = op_info->GetInputScale>(weight_name); + if (op_info->HasOutputScale(out_name)) + out_scale = op_info->GetOutputScale(out_name); } else { return FAILED; } diff --git a/lite/kernels/apu/bridges/pool_op.cc b/lite/kernels/apu/bridges/pool_op.cc index 2bda76ab99..66ef493d38 100644 --- a/lite/kernels/apu/bridges/pool_op.cc +++ b/lite/kernels/apu/bridges/pool_op.cc @@ -91,10 +91,12 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { float out_scale = 1.0f; if (op_info->HasAttr("enable_int8")) { if (op_info->GetAttr("enable_int8")) { - if (op_info->HasAttr("input_scale")) - x_scale = op_info->GetAttr("input_scale"); - if (op_info->HasAttr("output_scale")) - out_scale = op_info->GetAttr("output_scale"); + auto x_name = op_info->Input("X").front(); + auto out_name = op_info->Output("Out").front(); + if (op_info->HasInputScale(x_name)) + x_scale = op_info->GetInputScale(x_name); + if (op_info->HasOutputScale(out_name)) + out_scale = op_info->GetOutputScale(out_name); } else { LOG(WARNING) << "Do not enable_int8"; return FAILED; diff --git a/lite/kernels/apu/bridges/softmax_op.cc b/lite/kernels/apu/bridges/softmax_op.cc index 6a289ac987..06ce59d597 100644 --- a/lite/kernels/apu/bridges/softmax_op.cc +++ b/lite/kernels/apu/bridges/softmax_op.cc @@ -49,10 +49,12 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { float out_scale = 1.0f; if (op_info->HasAttr("enable_int8")) { if (op_info->GetAttr("enable_int8")) { - if (op_info->HasAttr("input_scale")) - input_scale = op_info->GetAttr("input_scale"); - if (op_info->HasAttr("output_scale")) - out_scale = op_info->GetAttr("output_scale"); + auto x_name = op_info->Input("X").front(); + auto out_name = op_info->Output("Out").front(); + if (op_info->HasInputScale(x_name)) + input_scale = op_info->GetInputScale(x_name); + if (op_info->HasOutputScale(out_name)) + out_scale = op_info->GetOutputScale(out_name); } else { LOG(WARNING) << "Do not enable_int8"; return FAILED; diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index c3e375e2e4..a71041bcf3 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -130,15 +130,19 @@ class ConvOpLite : public OpLite { padding_algorithm_ = op_desc.GetAttr("padding_algorithm"); } // For Int8 - if (op_desc.HasAttr("enable_int8")) { - param_.enable_int8 = op_desc.GetAttr("enable_int8"); - if (op_desc.HasAttr("input_scale")) - param_.input_scale = op_desc.GetAttr("input_scale"); - if (op_desc.HasAttr("weight_scale")) + const OpInfo* op_info = dynamic_cast(&op_desc); + if (op_info != nullptr && op_info->HasAttr("enable_int8")) { + param_.enable_int8 = op_info->GetAttr("enable_int8"); + auto input_name = op_info->Input("Input").front(); + auto filter_name = op_info->Input("Filter").front(); + auto output_name = op_info->Output("Output").front(); + if (op_info->HasInputScale(input_name)) + param_.input_scale = op_info->GetInputScale(input_name); + if (op_info->HasInputScale(filter_name)) param_.weight_scale = - op_desc.GetAttr>("weight_scale"); - if (op_desc.HasAttr("output_scale")) { - param_.output_scale = op_desc.GetAttr("output_scale"); + op_info->GetInputScale>(filter_name); + if (op_info->HasOutputScale(output_name)) { + param_.output_scale = op_info->GetOutputScale(output_name); } } diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index d4032c5e8b..0dade760ce 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -102,14 +102,19 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { } // For Int8 - if (op_desc.HasAttr("enable_int8")) { - param_.enable_int8 = op_desc.GetAttr("enable_int8"); - if (op_desc.HasAttr("input_scale")) - param_.input_scale = op_desc.GetAttr("input_scale"); - if (op_desc.HasAttr("weight_scale")) - param_.weight_scale = op_desc.GetAttr>("weight_scale"); - if (op_desc.HasAttr("output_scale")) - param_.output_scale = op_desc.GetAttr("output_scale"); + const OpInfo* op_info = dynamic_cast(&op_desc); + if (op_info != nullptr && op_info->HasAttr("enable_int8")) { + param_.enable_int8 = op_info->GetAttr("enable_int8"); + auto input_name = op_info->Input("Input").front(); + auto weight_name = op_info->Input("W").front(); + auto out_name = op_info->Output("Out").front(); + if (op_info->HasInputScale(input_name)) + param_.input_scale = op_info->GetInputScale(input_name); + if (op_info->HasInputScale(weight_name)) + param_.weight_scale = + op_info->GetInputScale>(weight_name); + if (op_info->HasOutputScale(out_name)) + param_.output_scale = op_info->GetOutputScale(out_name); } return true; } -- GitLab