未验证 提交 ee94db86 编写于 作者: C cc 提交者: GitHub

[Core ] Update quantization: save scales in op attrs by <inputname_index, scale_value> (#3816)

* Update quantization, scale save in op attrs by  <inputname_index, scale_value>, test=develop
Co-authored-by: Nhong19860320 <9973393+hong19860320@users.noreply.github.com>
上级 39b78a4c
...@@ -121,7 +121,7 @@ lite_cc_library(kernel SRCS kernel.cc ...@@ -121,7 +121,7 @@ lite_cc_library(kernel SRCS kernel.cc
PROFILE_DEPS lite_profiler PROFILE_DEPS lite_profiler
) )
lite_cc_library(op SRCS op_lite.cc DEPS scope op_registry target_wrapper kernel lite_cc_library(op SRCS op_lite.cc DEPS scope op_registry target_wrapper kernel
cpp_op_desc tensor cpp_op_desc tensor utils
) )
add_dependencies(kernel kernel_list_h) add_dependencies(kernel kernel_list_h)
......
...@@ -156,12 +156,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -156,12 +156,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
// little difference for int8 // little difference for int8
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
if (enable_int8) { if (enable_int8) {
PADDLE_ENFORCE(conv_op_desc->HasAttr("weight_scale"), std::string weight_name = conv_op_desc->Input("Filter").front();
PADDLE_ENFORCE(conv_op_desc->HasInputScale(weight_name),
"INT8 mode: Conv should has weight_scale attr"); "INT8 mode: Conv should has weight_scale attr");
auto conv_weight_d = conv_weight_t->mutable_data<int8_t>(); auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
// compute new conv_weight for int8 // compute new conv_weight for int8
auto weight_scale = auto weight_scale =
conv_op_desc->GetAttr<std::vector<float>>("weight_scale"); conv_op_desc->GetInputScale<std::vector<float>>(weight_name);
if (conv_type_ == "conv2d_transpose" && !depthwise) { if (conv_type_ == "conv2d_transpose" && !depthwise) {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
conv_weight_t->dims()[3]; conv_weight_t->dims()[3];
...@@ -188,7 +189,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -188,7 +189,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} }
} }
} }
conv_op_desc->SetAttr("weight_scale", weight_scale); conv_op_desc->SetInputScale(weight_name, weight_scale);
} else if (is_weight_quantization) { } else if (is_weight_quantization) {
std::string scale_name = conv_weight_name + "_quant_scale"; std::string scale_name = conv_weight_name + "_quant_scale";
if (conv_op_desc->HasAttr(scale_name)) { if (conv_op_desc->HasAttr(scale_name)) {
......
...@@ -71,7 +71,27 @@ void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -71,7 +71,27 @@ void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} }
cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info(); auto op_desc = *matched.at("mul")->stmt()->op_info();
// Get the input scale from mul
float x_scale{};
std::vector<float> y_scale_vct;
auto y_var_node = matched.at("W")->arg();
auto input_x_name = op_desc.Input("X").front();
auto input_y_name = op_desc.Input("Y").front();
bool is_quantized_op = op_desc.HasInputScale(input_x_name) &&
op_desc.HasInputScale(input_y_name);
if (is_quantized_op) {
x_scale = op_desc.GetInputScale<float>(input_x_name);
if (y_var_node->is_weight) { // the scale of y is a vector
y_scale_vct =
op_desc.GetInputScale<std::vector<float>>(op_desc.Input("Y").front());
} else {
y_scale_vct.push_back( // the scale of y is scalar
op_desc.GetInputScale<float>(op_desc.Input("Y").front()));
}
}
op_desc.mutable_inputs()->clear(); op_desc.mutable_inputs()->clear();
op_desc.mutable_outputs()->clear(); op_desc.mutable_outputs()->clear();
op_desc.SetType("fc"); op_desc.SetType("fc");
...@@ -85,6 +105,17 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -85,6 +105,17 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
if (with_relu_) { if (with_relu_) {
op_desc.SetAttr("activation_type", std::string{"relu"}); op_desc.SetAttr("activation_type", std::string{"relu"});
} }
// Set the input scale into fc
if (is_quantized_op) {
op_desc.SetInputScale(matched.at("x")->arg()->name, x_scale);
if (y_var_node->is_weight) {
op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct);
} else {
op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct.front());
}
}
return op_desc; return op_desc;
} }
......
...@@ -64,13 +64,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -64,13 +64,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
for (auto* quantized_node : outlinks) { for (auto* quantized_node : outlinks) {
// save input scale in quantized op by input argname + index // save input scale in quantized op by input argname + index
auto op_desc = *quantized_node->stmt()->mutable_op_info(); auto op_desc = *quantized_node->stmt()->mutable_op_info();
std::string argname; op_desc.SetInputScale(out_act_name, scale_value);
int index;
op_desc.GetInputArgname(out_act_name, &argname);
op_desc.GetInputIndex(out_act_name, &index);
op_desc.SetAttr<float>(argname + std::to_string(index) + "_input_scale",
scale_value);
op_desc.SetAttr<float>("input_scale", scale_value); // save it for now
op_desc.SetAttr<int>("bit_length", bit_length); op_desc.SetAttr<int>("bit_length", bit_length);
op_desc.UpdateAllInputs(out_act_name, in_act_name); op_desc.UpdateAllInputs(out_act_name, in_act_name);
quantized_node->stmt()->ResetOp(op_desc, graph->valid_places()); quantized_node->stmt()->ResetOp(op_desc, graph->valid_places());
...@@ -135,6 +129,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -135,6 +129,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
auto* quantized_op = matched.at("quantized_op"); auto* quantized_op = matched.at("quantized_op");
auto* dequant_op = matched.at("dequant_op"); auto* dequant_op = matched.at("dequant_op");
auto* dequant_op_out = matched.at("dequant_op_out"); auto* dequant_op_out = matched.at("dequant_op_out");
auto weight_name = quantized_op_weight->arg()->name;
// obtain weight_scale from max_range // obtain weight_scale from max_range
auto* scope = quantized_op->stmt()->op()->scope(); auto* scope = quantized_op->stmt()->op()->scope();
...@@ -150,7 +145,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -150,7 +145,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
// = max(abs(weight)) / range // = max(abs(weight)) / range
// set op desc // set op desc
cpp::OpDesc op_desc = *quantized_op->stmt()->op_info(); auto op_desc = *quantized_op->stmt()->op_info();
auto quantized_weight_var_name = quantized_op_weight->arg()->name; auto quantized_weight_var_name = quantized_op_weight->arg()->name;
auto quantized_weight_t = auto quantized_weight_t =
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>(); scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
...@@ -173,7 +168,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -173,7 +168,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
weight_scale.push_back(whole_weight_scale); weight_scale.push_back(whole_weight_scale);
} }
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("weight_scale", weight_scale); op_desc.SetInputScale(weight_name, weight_scale);
// change the weight from the float type to int8 type. // change the weight from the float type to int8 type.
Tensor temp_tensor; Tensor temp_tensor;
...@@ -246,6 +241,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -246,6 +241,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
auto* dequant_op_channel_scale = matched.at("dequant_op_channel_scale"); auto* dequant_op_channel_scale = matched.at("dequant_op_channel_scale");
auto* dequant_op = matched.at("dequant_op"); auto* dequant_op = matched.at("dequant_op");
auto* dequant_op_out = matched.at("dequant_op_out"); auto* dequant_op_out = matched.at("dequant_op_out");
auto weight_name = quantized_op_weight->arg()->name;
// obtain input weight_scale from fake_dequant op // obtain input weight_scale from fake_dequant op
auto* scope = quantized_op->stmt()->op()->scope(); auto* scope = quantized_op->stmt()->op()->scope();
...@@ -265,7 +261,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -265,7 +261,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
} }
// set op desc // set op desc
cpp::OpDesc op_desc = *quantized_op->stmt()->op_info(); auto op_desc = *quantized_op->stmt()->op_info();
if (quantized_op_type_ == "conv2d" || if (quantized_op_type_ == "conv2d" ||
quantized_op_type_ == "depthwise_conv2d") { quantized_op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {quantized_op_input->arg()->name}); op_desc.SetInput("Input", {quantized_op_input->arg()->name});
...@@ -275,7 +271,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -275,7 +271,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
} }
op_desc.SetAttr("enable_int8", true); op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("weight_scale", weight_scale); op_desc.SetInputScale(weight_name, weight_scale);
// change the weight from the float type to int8 type. // change the weight from the float type to int8 type.
auto quantized_weight_var_name = quantized_op_weight->arg()->name; auto quantized_weight_var_name = quantized_op_weight->arg()->name;
...@@ -352,22 +348,7 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, ...@@ -352,22 +348,7 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
// Save quantization info in op_info attr // Save quantization info in op_info attr
auto op_info = *quantized_node->stmt()->op_info(); auto op_info = *quantized_node->stmt()->op_info();
op_info.SetAttr<int>("bit_length", bit_length); op_info.SetAttr<int>("bit_length", bit_length);
op_info.SetInputScale(output_act_name, scale_value);
std::string argname;
int index;
op_info.GetInputArgname(output_act_name, &argname);
op_info.GetInputIndex(output_act_name, &index);
op_info.SetAttr<float>(argname + std::to_string(index) + "_input_scale",
scale_value);
std::string op_type = op_info.Type();
// Analyse the weight scale or input scale.
if (((op_type == "conv2d" || op_type == "depthwise_conv2d") &&
argname == "Input") ||
((op_type == "mul" || op_type == "matmul") && argname == "Y")) {
op_info.SetAttr<float>("weight_scale", scale_value);
} else {
op_info.SetAttr<float>("input_scale", scale_value);
}
op_info.UpdateAllInputs(output_act_name, input_act_name); op_info.UpdateAllInputs(output_act_name, input_act_name);
quantized_node->stmt()->ResetOp(op_info, graph->valid_places()); quantized_node->stmt()->ResetOp(op_info, graph->valid_places());
......
...@@ -37,17 +37,30 @@ void QuantizedOpAttributesInferencePass::Apply( ...@@ -37,17 +37,30 @@ void QuantizedOpAttributesInferencePass::Apply(
auto& inst = op_node->AsStmt(); auto& inst = op_node->AsStmt();
auto op_info = inst.op_info(); auto op_info = inst.op_info();
auto op_type = op_info->Type(); auto op_type = op_info->Type();
if (!op_info->HasAttr("input_scale")) continue;
bool found = false; // Check only if all of the inputs of the op have scale value
float output_scale; bool has_input_scale = true;
for (auto in_var_node : op_node->inlinks) {
CHECK(in_var_node->IsArg());
auto in_var_node_name = in_var_node->arg()->name;
has_input_scale &= op_info->HasInputScale(in_var_node_name);
}
if (!has_input_scale) continue;
// Infer the output scale according to its out_threshold or the input scale
// of its adjacent ops
bool is_quantized = true;
for (auto out_var_node : op_node->outlinks) { for (auto out_var_node : op_node->outlinks) {
CHECK(out_var_node->IsArg()); CHECK(out_var_node->IsArg());
bool found = false;
float output_scale;
auto out_var_node_name = out_var_node->arg()->name;
for (auto out_op_node : out_var_node->outlinks) { for (auto out_op_node : out_var_node->outlinks) {
CHECK(out_op_node->IsStmt()); CHECK(out_op_node->IsStmt());
auto& out_inst = out_op_node->AsStmt(); auto& out_inst = out_op_node->AsStmt();
auto out_op_info = out_inst.op_info(); auto out_op_info = out_inst.op_info();
if (!out_op_info->HasAttr("input_scale")) continue; if (!out_op_info->HasInputScale(out_var_node_name)) continue;
auto input_scale = out_op_info->GetAttr<float>("input_scale"); auto input_scale = out_op_info->GetInputScale<float>(out_var_node_name);
if (!found) { if (!found) {
found = true; found = true;
output_scale = input_scale; output_scale = input_scale;
...@@ -55,16 +68,22 @@ void QuantizedOpAttributesInferencePass::Apply( ...@@ -55,16 +68,22 @@ void QuantizedOpAttributesInferencePass::Apply(
CHECK_EQ(output_scale, input_scale); CHECK_EQ(output_scale, input_scale);
} }
} }
if (found) {
inst.mutable_op_info()->SetOutputScale(out_var_node_name, output_scale);
} else if (op_info->HasAttr("out_threshold")) {
// Only consider one output, there are only one out_threshold
int bit_length = op_info->GetAttr<int>("bit_length");
int range = (1 << (bit_length - 1)) - 1;
output_scale = op_info->GetAttr<float>("out_threshold");
inst.mutable_op_info()->SetOutputScale(out_var_node_name,
output_scale / range);
} else {
is_quantized = false;
}
} }
if (found) {
inst.mutable_op_info()->SetAttr("output_scale", output_scale); // Fix the missing of the attribute 'enable_int8'.
} else if (op_info->HasAttr("output_scale")) { if (is_quantized) {
int bit_length = op_info->GetAttr<int>("bit_length");
int range = (1 << (bit_length - 1)) - 1;
output_scale = op_info->GetAttr<float>("output_scale");
inst.mutable_op_info()->SetAttr("output_scale", output_scale / range);
}
if (op_info->HasAttr("output_scale")) {
inst.mutable_op_info()->SetAttr("enable_int8", true); inst.mutable_op_info()->SetAttr("enable_int8", true);
} }
} }
......
...@@ -110,15 +110,16 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -110,15 +110,16 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (out_type_int8) { if (out_type_int8) {
auto out_node = node.outlinks.front(); auto out_node = node.outlinks.front();
CHECK(out_node->IsArg()); CHECK(out_node->IsArg());
auto out_node_name = out_node->arg()->name;
auto one_adj_op_node = out_node->outlinks.front(); auto one_adj_op_node = out_node->outlinks.front();
CHECK(one_adj_op_node->IsStmt()); CHECK(one_adj_op_node->IsStmt());
auto& one_adj_instruct = one_adj_op_node->AsStmt(); auto& one_adj_instruct = one_adj_op_node->AsStmt();
CHECK(one_adj_instruct.op_info()->HasAttr("enable_int8")); CHECK(one_adj_instruct.op_info()->HasAttr("enable_int8"));
CHECK(one_adj_instruct.op_info()->HasAttr("input_scale")); CHECK(one_adj_instruct.op_info()->HasInputScale(out_node_name));
instruct.mutable_op_info()->SetAttr( instruct.mutable_op_info()->SetOutputScale(
"output_scale", out_node_name,
one_adj_instruct.op_info()->GetAttr<float>("input_scale")); one_adj_instruct.op_info()->GetInputScale<float>(out_node_name));
auto update_desc = *instruct.mutable_op_info(); auto update_desc = *instruct.mutable_op_info();
instruct.ResetOp(update_desc, graph->valid_places()); instruct.ResetOp(update_desc, graph->valid_places());
......
...@@ -457,21 +457,23 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -457,21 +457,23 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
std::vector<float> input_data_scales; std::vector<float> input_data_scales;
std::vector<float> output_data_scales; std::vector<float> output_data_scales;
for (auto &var_node : input_var_nodes) { for (auto &var_node : input_var_nodes) {
auto var_node_name = var_node->arg()->name;
auto any_op_node = var_node->outlinks.front(); auto any_op_node = var_node->outlinks.front();
CHECK(any_op_node->IsStmt()); CHECK(any_op_node->IsStmt());
auto &any_inst = any_op_node->AsStmt(); auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasAttr("input_scale")) { if (any_inst.op_info()->HasInputScale(var_node_name)) {
input_data_scales.push_back( input_data_scales.push_back(
any_inst.op_info()->GetAttr<float>("input_scale")); any_inst.op_info()->GetInputScale<float>(var_node_name));
} }
} }
for (auto &var_node : output_var_nodes) { for (auto &var_node : output_var_nodes) {
auto var_node_name = var_node->arg()->name;
auto any_op_node = var_node->inlinks.front(); auto any_op_node = var_node->inlinks.front();
CHECK(any_op_node->IsStmt()); CHECK(any_op_node->IsStmt());
auto &any_inst = any_op_node->AsStmt(); auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasAttr("output_scale")) { if (any_inst.op_info()->HasOutputScale(var_node_name)) {
output_data_scales.push_back( output_data_scales.push_back(
any_inst.op_info()->GetAttr<float>("output_scale")); any_inst.op_info()->GetOutputScale<float>(var_node_name));
} }
} }
if (input_data_scales.size() > 0) { if (input_data_scales.size() > 0) {
......
...@@ -107,8 +107,8 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) { ...@@ -107,8 +107,8 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) {
if (op_type == "subgraph") { if (op_type == "subgraph") {
found = InferScaleFromSubgraph(var_name, op_info, scale, false); found = InferScaleFromSubgraph(var_name, op_info, scale, false);
} else { } else {
if (op_info->HasAttr("input_scale")) { if (op_info->HasInputScale(var_name)) {
*scale = op_info->GetAttr<float>("input_scale"); *scale = op_info->GetInputScale<float>(var_name);
found = true; found = true;
} else { } else {
// Obtain the output_scale from one of its previous Ops // Obtain the output_scale from one of its previous Ops
...@@ -120,8 +120,8 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) { ...@@ -120,8 +120,8 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) {
if (prev_op_type == "subgraph") { if (prev_op_type == "subgraph") {
found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true); found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true);
} else { } else {
if (prev_op_info->HasAttr("output_scale")) { if (prev_op_info->HasOutputScale(var_name)) {
*scale = prev_op_info->GetAttr<float>("output_scale"); *scale = prev_op_info->GetOutputScale<float>(var_name);
found = true; found = true;
} }
} }
......
...@@ -22,6 +22,14 @@ ...@@ -22,6 +22,14 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
// Convert an integer to its decimal string form. Used to build the
// "<argname><index>_scale" attribute keys for quantization scales.
std::string int2string(int index) {
  // std::to_string handles any int (including negatives) without a
  // fixed-size buffer, so the previous snprintf + truncation CHECK
  // (which also had a signed/unsigned comparison) is unnecessary.
  return std::to_string(index);
}
bool OpLite::InferShape() { bool OpLite::InferShape() {
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_ // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
// InferShapeByMemoryInternal will be applied. // InferShapeByMemoryInternal will be applied.
...@@ -186,5 +194,115 @@ void OpLite::AttachOutput(const cpp::OpDesc &op_desc, ...@@ -186,5 +194,115 @@ void OpLite::AttachOutput(const cpp::OpDesc &op_desc,
} }
} }
// Find which input argument slot (e.g. "X", "Filter") the variable
// `value_name` belongs to. Returns false if it is not an input of this op.
bool OpInfo::GetInputArgname(const std::string &value_name,
                             std::string *out) const {
  for (auto it = inputs_.begin(); it != inputs_.end(); ++it) {
    const auto &var_names = it->second;
    if (std::find(var_names.begin(), var_names.end(), value_name) !=
        var_names.end()) {
      *out = it->first;
      return true;
    }
  }
  return false;
}
// Find which output argument slot (e.g. "Out") the variable `value_name`
// belongs to. Returns false if it is not an output of this op.
bool OpInfo::GetOutputArgname(const std::string &value_name,
                              std::string *out) const {
  for (auto it = outputs_.begin(); it != outputs_.end(); ++it) {
    const auto &var_names = it->second;
    if (std::find(var_names.begin(), var_names.end(), value_name) !=
        var_names.end()) {
      *out = it->first;
      return true;
    }
  }
  return false;
}
// Find the position of `input_name` within its argument's variable list.
// Returns false if the variable is not an input of this op.
bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const {
  for (auto &item : inputs_) {
    const auto &var_names = item.second;
    for (size_t i = 0; i < var_names.size(); ++i) {
      if (var_names[i] == input_name) {
        *out = static_cast<int>(i);
        return true;
      }
    }
  }
  return false;
}
// Find the position of `output_name` within its argument's variable list.
// Returns false if the variable is not an output of this op.
bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const {
  for (auto &item : outputs_) {
    const auto &var_names = item.second;
    for (size_t i = 0; i < var_names.size(); ++i) {
      if (var_names[i] == output_name) {
        *out = static_cast<int>(i);
        return true;
      }
    }
  }
  return false;
}
// An input has a scale iff the attr "<argname><index>_scale" exists.
bool OpInfo::HasInputScale(const std::string &input_name) const {
  std::string argname;
  int index;
  if (!GetInputArgname(input_name, &argname)) return false;
  if (!GetInputIndex(input_name, &index)) return false;
  return HasAttr(argname + int2string(index) + "_scale");
}
// An output has a scale iff the attr "<argname><index>_scale" exists.
bool OpInfo::HasOutputScale(const std::string &output_name) const {
  std::string argname;
  int index;
  if (!GetOutputArgname(output_name, &argname)) return false;
  if (!GetOutputIndex(output_name, &index)) return false;
  return HasAttr(argname + int2string(index) + "_scale");
}
// Save a scalar scale for the given input as attr "<argname><index>_scale".
// Crashes (CHECK) if `input_name` is not an input of this op.
template <>
void OpInfo::SetInputScale(const std::string &input_name,
                           const float &scale_value) {
  std::string argname;
  int index;
  CHECK(GetInputArgname(input_name, &argname));
  CHECK(GetInputIndex(input_name, &index));
  const std::string attr_name = argname + int2string(index) + "_scale";
  SetAttr<float>(attr_name, scale_value);
}
// Save a per-channel scale vector for the given input as attr
// "<argname><index>_scale". Crashes (CHECK) if `input_name` is not an
// input of this op.
template <>
void OpInfo::SetInputScale(const std::string &input_name,
                           const std::vector<float> &scale_value) {
  std::string argname;
  int index;
  CHECK(GetInputArgname(input_name, &argname));
  CHECK(GetInputIndex(input_name, &index));
  const std::string attr_name = argname + int2string(index) + "_scale";
  SetAttr<std::vector<float>>(attr_name, scale_value);
}
// Save a scalar scale for the given output as attr "<argname><index>_scale".
// Crashes (CHECK) if `output_name` is not an output of this op.
template <>
void OpInfo::SetOutputScale(const std::string &output_name,
                            const float &scale_value) {
  std::string argname;
  int index;
  CHECK(GetOutputArgname(output_name, &argname));
  CHECK(GetOutputIndex(output_name, &index));
  const std::string attr_name = argname + int2string(index) + "_scale";
  SetAttr<float>(attr_name, scale_value);
}
// Save a per-channel scale vector for the given output as attr
// "<argname><index>_scale". Crashes (CHECK) if `output_name` is not an
// output of this op.
template <>
void OpInfo::SetOutputScale(const std::string &output_name,
                            const std::vector<float> &scale_value) {
  std::string argname;
  int index;
  CHECK(GetOutputArgname(output_name, &argname));
  CHECK(GetOutputIndex(output_name, &index));
  const std::string attr_name = argname + int2string(index) + "_scale";
  SetAttr<std::vector<float>>(attr_name, scale_value);
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
std::string int2string(int index);
// For registry factory. // For registry factory.
struct Registry { struct Registry {
void Touch() {} void Touch() {}
...@@ -229,51 +231,36 @@ class OpInfo : public cpp::OpDesc { ...@@ -229,51 +231,36 @@ class OpInfo : public cpp::OpDesc {
return OutputArgumentNames(); return OutputArgumentNames();
} }
bool GetInputArgname(const std::string &value_name, std::string *out) const { bool GetInputArgname(const std::string &value_name, std::string *out) const;
for (auto &item : inputs_) { bool GetOutputArgname(const std::string &value_name, std::string *out) const;
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
*out = item.first;
return true;
}
}
return false;
}
bool GetOutputArgname(const std::string &value_name, std::string *out) const {
for (auto &item : outputs_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
*out = item.first;
return true;
}
}
return false;
}
// For the input variable name, find the index of the corresponding bool GetInputIndex(const std::string &input_name, int *out) const;
// input argname bool GetOutputIndex(const std::string &output_name, int *out) const;
bool GetInputIndex(const std::string &value_name, int *out) const {
for (auto &item : inputs_) { bool HasInputScale(const std::string &input_name) const;
auto it = std::find(item.second.begin(), item.second.end(), value_name); bool HasOutputScale(const std::string &output_name) const;
if (it != item.second.end()) {
*out = it - item.second.begin(); template <typename T>
return true; void SetInputScale(const std::string &input_name, const T &scale_value);
} template <typename T>
} void SetOutputScale(const std::string &output_name, const T &scale_value);
return false;
template <typename T>
T GetInputScale(const std::string &input_name) const {
std::string argname;
int index;
CHECK(GetInputArgname(input_name, &argname));
CHECK(GetInputIndex(input_name, &index));
return GetAttr<T>(argname + int2string(index) + "_scale");
} }
// For the output variable name, find the index of the corresponding template <typename T>
// output argname T GetOutputScale(const std::string &output_name) const {
bool GetOutputIndex(const std::string &value_name, int *out) const { std::string argname;
for (auto &item : outputs_) { int index;
auto it = std::find(item.second.begin(), item.second.end(), value_name); CHECK(GetOutputArgname(output_name, &argname));
if (it != item.second.end()) { CHECK(GetOutputIndex(output_name, &index));
*out = it - item.second.begin(); return GetAttr<T>(argname + int2string(index) + "_scale");
return true;
}
}
return false;
} }
void UpdateAllInputs(const std::string &from, const std::string &to) { void UpdateAllInputs(const std::string &from, const std::string &to) {
......
...@@ -99,12 +99,16 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -99,12 +99,16 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
std::vector<float> weight_scale; std::vector<float> weight_scale;
if (op_info->HasAttr("enable_int8")) { if (op_info->HasAttr("enable_int8")) {
if (op_info->GetAttr<bool>("enable_int8")) { if (op_info->GetAttr<bool>("enable_int8")) {
if (op_info->HasAttr("input_scale")) auto input_name = op_info->Input("Input").front();
input_scale = op_info->GetAttr<float>("input_scale"); auto filter_name = op_info->Input("Filter").front();
if (op_info->HasAttr("weight_scale")) auto output_name = op_info->Output("Output").front();
weight_scale = op_info->GetAttr<std::vector<float>>("weight_scale"); if (op_info->HasInputScale(input_name))
if (op_info->HasAttr("output_scale")) input_scale = op_info->GetInputScale<float>(input_name);
output_scale = op_info->GetAttr<float>("output_scale"); if (op_info->HasInputScale(filter_name))
weight_scale = op_info->GetInputScale<std::vector<float>>(filter_name);
if (op_info->HasOutputScale(output_name)) {
output_scale = op_info->GetOutputScale<float>(output_name);
}
VLOG(3) << "has output scale:" << output_scale; VLOG(3) << "has output scale:" << output_scale;
} else { } else {
return FAILED; return FAILED;
......
...@@ -57,12 +57,15 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -57,12 +57,15 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
std::vector<float> w_scale; std::vector<float> w_scale;
if (op_info->HasAttr("enable_int8")) { if (op_info->HasAttr("enable_int8")) {
if (op_info->GetAttr<bool>("enable_int8")) { if (op_info->GetAttr<bool>("enable_int8")) {
if (op_info->HasAttr("input_scale")) auto input_name = op_info->Input("Input").front();
input_scale = op_info->GetAttr<float>("input_scale"); auto weight_name = op_info->Input("W").front();
if (op_info->HasAttr("weight_scale")) auto out_name = op_info->Output("Out").front();
w_scale = op_info->GetAttr<std::vector<float>>("weight_scale"); if (op_info->HasInputScale(input_name))
if (op_info->HasAttr("output_scale")) input_scale = op_info->GetInputScale<float>(input_name);
out_scale = op_info->GetAttr<float>("output_scale"); if (op_info->HasInputScale(weight_name))
w_scale = op_info->GetInputScale<std::vector<float>>(weight_name);
if (op_info->HasOutputScale(out_name))
out_scale = op_info->GetOutputScale<float>(out_name);
} else { } else {
return FAILED; return FAILED;
} }
......
...@@ -91,10 +91,12 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -91,10 +91,12 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
float out_scale = 1.0f; float out_scale = 1.0f;
if (op_info->HasAttr("enable_int8")) { if (op_info->HasAttr("enable_int8")) {
if (op_info->GetAttr<bool>("enable_int8")) { if (op_info->GetAttr<bool>("enable_int8")) {
if (op_info->HasAttr("input_scale")) auto x_name = op_info->Input("X").front();
x_scale = op_info->GetAttr<float>("input_scale"); auto out_name = op_info->Output("Out").front();
if (op_info->HasAttr("output_scale")) if (op_info->HasInputScale(x_name))
out_scale = op_info->GetAttr<float>("output_scale"); x_scale = op_info->GetInputScale<float>(x_name);
if (op_info->HasOutputScale(out_name))
out_scale = op_info->GetOutputScale<float>(out_name);
} else { } else {
LOG(WARNING) << "Do not enable_int8"; LOG(WARNING) << "Do not enable_int8";
return FAILED; return FAILED;
......
...@@ -49,10 +49,12 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -49,10 +49,12 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
float out_scale = 1.0f; float out_scale = 1.0f;
if (op_info->HasAttr("enable_int8")) { if (op_info->HasAttr("enable_int8")) {
if (op_info->GetAttr<bool>("enable_int8")) { if (op_info->GetAttr<bool>("enable_int8")) {
if (op_info->HasAttr("input_scale")) auto x_name = op_info->Input("X").front();
input_scale = op_info->GetAttr<float>("input_scale"); auto out_name = op_info->Output("Out").front();
if (op_info->HasAttr("output_scale")) if (op_info->HasInputScale(x_name))
out_scale = op_info->GetAttr<float>("output_scale"); input_scale = op_info->GetInputScale<float>(x_name);
if (op_info->HasOutputScale(out_name))
out_scale = op_info->GetOutputScale<float>(out_name);
} else { } else {
LOG(WARNING) << "Do not enable_int8"; LOG(WARNING) << "Do not enable_int8";
return FAILED; return FAILED;
......
...@@ -130,15 +130,19 @@ class ConvOpLite : public OpLite { ...@@ -130,15 +130,19 @@ class ConvOpLite : public OpLite {
padding_algorithm_ = op_desc.GetAttr<std::string>("padding_algorithm"); padding_algorithm_ = op_desc.GetAttr<std::string>("padding_algorithm");
} }
// For Int8 // For Int8
if (op_desc.HasAttr("enable_int8")) { const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc);
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8"); if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
if (op_desc.HasAttr("input_scale")) param_.enable_int8 = op_info->GetAttr<bool>("enable_int8");
param_.input_scale = op_desc.GetAttr<float>("input_scale"); auto input_name = op_info->Input("Input").front();
if (op_desc.HasAttr("weight_scale")) auto filter_name = op_info->Input("Filter").front();
auto output_name = op_info->Output("Output").front();
if (op_info->HasInputScale(input_name))
param_.input_scale = op_info->GetInputScale<float>(input_name);
if (op_info->HasInputScale(filter_name))
param_.weight_scale = param_.weight_scale =
op_desc.GetAttr<std::vector<float>>("weight_scale"); op_info->GetInputScale<std::vector<float>>(filter_name);
if (op_desc.HasAttr("output_scale")) { if (op_info->HasOutputScale(output_name)) {
param_.output_scale = op_desc.GetAttr<float>("output_scale"); param_.output_scale = op_info->GetOutputScale<float>(output_name);
} }
} }
......
...@@ -102,14 +102,19 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { ...@@ -102,14 +102,19 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
} }
// For Int8 // For Int8
if (op_desc.HasAttr("enable_int8")) { const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc);
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8"); if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
if (op_desc.HasAttr("input_scale")) param_.enable_int8 = op_info->GetAttr<bool>("enable_int8");
param_.input_scale = op_desc.GetAttr<float>("input_scale"); auto input_name = op_info->Input("Input").front();
if (op_desc.HasAttr("weight_scale")) auto weight_name = op_info->Input("W").front();
param_.weight_scale = op_desc.GetAttr<std::vector<float>>("weight_scale"); auto out_name = op_info->Output("Out").front();
if (op_desc.HasAttr("output_scale")) if (op_info->HasInputScale(input_name))
param_.output_scale = op_desc.GetAttr<float>("output_scale"); param_.input_scale = op_info->GetInputScale<float>(input_name);
if (op_info->HasInputScale(weight_name))
param_.weight_scale =
op_info->GetInputScale<std::vector<float>>(weight_name);
if (op_info->HasOutputScale(out_name))
param_.output_scale = op_info->GetOutputScale<float>(out_name);
} }
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册