Unverified · Commit ee94db86 · authored by cc, committed by GitHub

[Core] Update quantization: save scales in op attrs by <inputname_index, scale_value> (#3816)

* Update quantization: scales are saved in op attrs as <inputname_index, scale_value>, test=develop
Co-authored-by: hong19860320 <9973393+hong19860320@users.noreply.github.com>

Parent: 39b78a4c
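The change replaces the single global attributes `input_scale`, `weight_scale`, and `output_scale` with per-tensor attributes keyed by `<argname><index>_scale` (e.g. `Input0_scale`, `W0_scale`), so an op with several quantized inputs can carry one scale per input. A minimal standalone sketch of the key scheme (illustration only, not PaddleLite code):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Build the attribute key for the index-th variable of an argument,
// e.g. ("X", 0) -> "X0_scale". Mirrors the scheme OpInfo uses in this patch.
std::string ScaleKey(const std::string& argname, int index) {
  return argname + std::to_string(index) + "_scale";
}

int main() {
  std::map<std::string, std::vector<float>> attrs;
  attrs[ScaleKey("X", 0)] = {0.017f};           // activation: one scale
  attrs[ScaleKey("Filter", 0)] = {0.1f, 0.2f};  // weight: per-channel scales
  assert(attrs.count("X0_scale") == 1);
  assert(attrs.at("Filter0_scale").size() == 2);
  return 0;
}
```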
@@ -121,7 +121,7 @@ lite_cc_library(kernel SRCS kernel.cc
       PROFILE_DEPS lite_profiler
   )
 lite_cc_library(op SRCS op_lite.cc DEPS scope op_registry target_wrapper kernel
-    cpp_op_desc tensor
+    cpp_op_desc tensor utils
   )
 add_dependencies(kernel kernel_list_h)
......
@@ -156,12 +156,13 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   // little difference for int8
   ///////////////////////////////////////////////////////////////////////////////
   if (enable_int8) {
-    PADDLE_ENFORCE(conv_op_desc->HasAttr("weight_scale"),
+    std::string weight_name = conv_op_desc->Input("Filter").front();
+    PADDLE_ENFORCE(conv_op_desc->HasInputScale(weight_name),
                    "INT8 mode: Conv should has weight_scale attr");
     auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
     // compute new conv_weight for int8
     auto weight_scale =
-        conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
+        conv_op_desc->GetInputScale<std::vector<float>>(weight_name);
     if (conv_type_ == "conv2d_transpose" && !depthwise) {
       int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
                    conv_weight_t->dims()[3];
@@ -188,7 +189,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
         }
       }
     }
-    conv_op_desc->SetAttr("weight_scale", weight_scale);
+    conv_op_desc->SetInputScale(weight_name, weight_scale);
   } else if (is_weight_quantization) {
     std::string scale_name = conv_weight_name + "_quant_scale";
     if (conv_op_desc->HasAttr(scale_name)) {
......
@@ -71,7 +71,27 @@ void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
 }
 
 cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
-  cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
+  auto op_desc = *matched.at("mul")->stmt()->op_info();
+
+  // Get the input scale from mul
+  float x_scale{};
+  std::vector<float> y_scale_vct;
+  auto y_var_node = matched.at("W")->arg();
+  auto input_x_name = op_desc.Input("X").front();
+  auto input_y_name = op_desc.Input("Y").front();
+  bool is_quantized_op = op_desc.HasInputScale(input_x_name) &&
+                         op_desc.HasInputScale(input_y_name);
+  if (is_quantized_op) {
+    x_scale = op_desc.GetInputScale<float>(input_x_name);
+    if (y_var_node->is_weight) {  // the scale of y is a vector
+      y_scale_vct =
+          op_desc.GetInputScale<std::vector<float>>(op_desc.Input("Y").front());
+    } else {  // the scale of y is a scalar
+      y_scale_vct.push_back(
+          op_desc.GetInputScale<float>(op_desc.Input("Y").front()));
+    }
+  }
+
   op_desc.mutable_inputs()->clear();
   op_desc.mutable_outputs()->clear();
   op_desc.SetType("fc");
@@ -85,6 +105,17 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
   if (with_relu_) {
     op_desc.SetAttr("activation_type", std::string{"relu"});
   }
+
+  // Set the input scale into fc
+  if (is_quantized_op) {
+    op_desc.SetInputScale(matched.at("x")->arg()->name, x_scale);
+    if (y_var_node->is_weight) {
+      op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct);
+    } else {
+      op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct.front());
+    }
+  }
   return op_desc;
 }
......
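Note on the FcFuser hunk above: the scales are read off the matched `mul` op before its inputs and outputs are cleared, then re-keyed onto the fused `fc` op, and whether `Y` carries a per-channel vector or a single scalar depends on whether it is a weight. A hedged sketch of that re-keying with stand-in maps (not the real `cpp::OpDesc` API; values are made up):

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Scales as recorded on the matched "mul" op, keyed by <argname><index>_scale.
  std::map<std::string, std::vector<float>> mul_attrs = {
      {"X0_scale", {0.02f}},        // activation input: scalar scale
      {"Y0_scale", {0.1f, 0.3f}}};  // weight input: per-channel vector

  // After fusion the same values must live on the new "fc" op under its own
  // argnames, since the variable that fed "Y" on mul now feeds "W" on fc.
  std::map<std::string, std::vector<float>> fc_attrs;
  fc_attrs["Input0_scale"] = mul_attrs.at("X0_scale");
  fc_attrs["W0_scale"] = mul_attrs.at("Y0_scale");
  std::cout << "W0_scale carries " << fc_attrs.at("W0_scale").size()
            << " per-channel values\n";
  return 0;
}
```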
@@ -64,13 +64,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
   for (auto* quantized_node : outlinks) {
     // save input scale in quantized op by input argname + index
     auto op_desc = *quantized_node->stmt()->mutable_op_info();
-    std::string argname;
-    int index;
-    op_desc.GetInputArgname(out_act_name, &argname);
-    op_desc.GetInputIndex(out_act_name, &index);
-    op_desc.SetAttr<float>(argname + std::to_string(index) + "_input_scale",
-                           scale_value);
-    op_desc.SetAttr<float>("input_scale", scale_value);  // save it for now
+    op_desc.SetInputScale(out_act_name, scale_value);
     op_desc.SetAttr<int>("bit_length", bit_length);
     op_desc.UpdateAllInputs(out_act_name, in_act_name);
     quantized_node->stmt()->ResetOp(op_desc, graph->valid_places());
@@ -135,6 +129,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
   auto* quantized_op = matched.at("quantized_op");
   auto* dequant_op = matched.at("dequant_op");
   auto* dequant_op_out = matched.at("dequant_op_out");
+  auto weight_name = quantized_op_weight->arg()->name;
 
   // obtain weight_scale from max_range
   auto* scope = quantized_op->stmt()->op()->scope();
@@ -150,7 +145,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
   //                 = max(abs(weight)) / range
 
   // set op desc
-  cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
+  auto op_desc = *quantized_op->stmt()->op_info();
   auto quantized_weight_var_name = quantized_op_weight->arg()->name;
   auto quantized_weight_t =
       scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
@@ -173,7 +168,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
     weight_scale.push_back(whole_weight_scale);
   }
   op_desc.SetAttr("enable_int8", true);
-  op_desc.SetAttr("weight_scale", weight_scale);
+  op_desc.SetInputScale(weight_name, weight_scale);
 
   // change the weight from the float type to int8 type.
   Tensor temp_tensor;
@@ -246,6 +241,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
   auto* dequant_op_channel_scale = matched.at("dequant_op_channel_scale");
   auto* dequant_op = matched.at("dequant_op");
   auto* dequant_op_out = matched.at("dequant_op_out");
+  auto weight_name = quantized_op_weight->arg()->name;
 
   // obtain input weight_scale from fake_dequant op
   auto* scope = quantized_op->stmt()->op()->scope();
@@ -265,7 +261,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
   }
 
   // set op desc
-  cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
+  auto op_desc = *quantized_op->stmt()->op_info();
   if (quantized_op_type_ == "conv2d" ||
       quantized_op_type_ == "depthwise_conv2d") {
     op_desc.SetInput("Input", {quantized_op_input->arg()->name});
@@ -275,7 +271,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
     op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
   }
   op_desc.SetAttr("enable_int8", true);
-  op_desc.SetAttr("weight_scale", weight_scale);
+  op_desc.SetInputScale(weight_name, weight_scale);
 
   // change the weight from the float type to int8 type.
   auto quantized_weight_var_name = quantized_op_weight->arg()->name;
@@ -352,22 +348,7 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
     // Save quantization info in op_info attr
     auto op_info = *quantized_node->stmt()->op_info();
     op_info.SetAttr<int>("bit_length", bit_length);
-    std::string argname;
-    int index;
-    op_info.GetInputArgname(output_act_name, &argname);
-    op_info.GetInputIndex(output_act_name, &index);
-    op_info.SetAttr<float>(argname + std::to_string(index) + "_input_scale",
-                           scale_value);
-
-    std::string op_type = op_info.Type();
-    // Analyse the weight scale or input scale.
-    if (((op_type == "conv2d" || op_type == "depthwise_conv2d") &&
-         argname == "Input") ||
-        ((op_type == "mul" || op_type == "matmul") && argname == "Y")) {
-      op_info.SetAttr<float>("weight_scale", scale_value);
-    } else {
-      op_info.SetAttr<float>("input_scale", scale_value);
-    }
+    op_info.SetInputScale(output_act_name, scale_value);
     op_info.UpdateAllInputs(output_act_name, input_act_name);
     quantized_node->stmt()->ResetOp(op_info, graph->valid_places());
......
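The last hunk above is the main simplification of this patch: the old code had to guess from the op type and argname whether a value was a `weight_scale` or an `input_scale`, while keying the attribute by the concrete input variable makes that branch unnecessary. A compilable sketch contrasting the two schemes (stand-in type, not the real `OpInfo`):

```cpp
#include <map>
#include <string>

struct OpInfoLike {  // stand-in for paddle::lite::OpInfo
  std::map<std::string, float> attrs;
  void SetAttr(const std::string& key, float v) { attrs[key] = v; }
};

// Old scheme: per-op special cases decide the attribute name.
void OldSave(OpInfoLike* info, const std::string& op_type,
             const std::string& argname, float scale) {
  if (((op_type == "conv2d" || op_type == "depthwise_conv2d") &&
       argname == "Input") ||
      ((op_type == "mul" || op_type == "matmul") && argname == "Y")) {
    info->SetAttr("weight_scale", scale);
  } else {
    info->SetAttr("input_scale", scale);
  }
}

// New scheme: one uniform key per input variable, no special cases.
void NewSave(OpInfoLike* info, const std::string& argname, int index,
             float scale) {
  info->SetAttr(argname + std::to_string(index) + "_scale", scale);
}

int main() {
  OpInfoLike info;
  OldSave(&info, "mul", "Y", 0.1f);  // lands in "weight_scale"
  NewSave(&info, "Y", 0, 0.1f);      // lands in "Y0_scale"
  return 0;
}
```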
@@ -37,17 +37,30 @@ void QuantizedOpAttributesInferencePass::Apply(
     auto& inst = op_node->AsStmt();
     auto op_info = inst.op_info();
     auto op_type = op_info->Type();
-    if (!op_info->HasAttr("input_scale")) continue;
-    bool found = false;
-    float output_scale;
+
+    // Check only if all of the inputs of the op have scale value
+    bool has_input_scale = true;
+    for (auto in_var_node : op_node->inlinks) {
+      CHECK(in_var_node->IsArg());
+      auto in_var_node_name = in_var_node->arg()->name;
+      has_input_scale &= op_info->HasInputScale(in_var_node_name);
+    }
+    if (!has_input_scale) continue;
+
+    // Infer the output scale according to its out_threshold or the input scale
+    // of its adjacent ops
+    bool is_quantized = true;
     for (auto out_var_node : op_node->outlinks) {
       CHECK(out_var_node->IsArg());
+      bool found = false;
+      float output_scale;
+      auto out_var_node_name = out_var_node->arg()->name;
       for (auto out_op_node : out_var_node->outlinks) {
         CHECK(out_op_node->IsStmt());
         auto& out_inst = out_op_node->AsStmt();
         auto out_op_info = out_inst.op_info();
-        if (!out_op_info->HasAttr("input_scale")) continue;
-        auto input_scale = out_op_info->GetAttr<float>("input_scale");
+        if (!out_op_info->HasInputScale(out_var_node_name)) continue;
+        auto input_scale = out_op_info->GetInputScale<float>(out_var_node_name);
         if (!found) {
           found = true;
           output_scale = input_scale;
@@ -55,16 +68,22 @@ void QuantizedOpAttributesInferencePass::Apply(
           CHECK_EQ(output_scale, input_scale);
         }
       }
+      if (found) {
+        inst.mutable_op_info()->SetOutputScale(out_var_node_name, output_scale);
+      } else if (op_info->HasAttr("out_threshold")) {
+        // Only consider one output, there are only one out_threshold
+        int bit_length = op_info->GetAttr<int>("bit_length");
+        int range = (1 << (bit_length - 1)) - 1;
+        output_scale = op_info->GetAttr<float>("out_threshold");
+        inst.mutable_op_info()->SetOutputScale(out_var_node_name,
+                                               output_scale / range);
+      } else {
+        is_quantized = false;
+      }
     }
-    if (found) {
-      inst.mutable_op_info()->SetAttr("output_scale", output_scale);
-    } else if (op_info->HasAttr("output_scale")) {
-      int bit_length = op_info->GetAttr<int>("bit_length");
-      int range = (1 << (bit_length - 1)) - 1;
-      output_scale = op_info->GetAttr<float>("output_scale");
-      inst.mutable_op_info()->SetAttr("output_scale", output_scale / range);
-    }
-    if (op_info->HasAttr("output_scale")) {
-      // Fix the missing of the attribute 'enable_int8'.
+
+    // Fix the missing of the attribute 'enable_int8'.
+    if (is_quantized) {
      inst.mutable_op_info()->SetAttr("enable_int8", true);
    }
  }
......
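For reference, the fallback path above divides `out_threshold` by `range = 2^(bit_length-1) - 1`; with the usual `bit_length = 8` that is 127. A tiny numeric check:

```cpp
#include <cstdio>

int main() {
  int bit_length = 8;
  int range = (1 << (bit_length - 1)) - 1;  // (1 << 7) - 1 = 127
  float out_threshold = 6.35f;
  // out_threshold is the max absolute activation value; dividing by the
  // integer range turns it into a per-unit quantization scale.
  printf("output scale = %f\n", out_threshold / range);  // ~0.05
  return 0;
}
```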
@@ -110,15 +110,16 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
         if (out_type_int8) {
           auto out_node = node.outlinks.front();
           CHECK(out_node->IsArg());
+          auto out_node_name = out_node->arg()->name;
           auto one_adj_op_node = out_node->outlinks.front();
           CHECK(one_adj_op_node->IsStmt());
           auto& one_adj_instruct = one_adj_op_node->AsStmt();
           CHECK(one_adj_instruct.op_info()->HasAttr("enable_int8"));
-          CHECK(one_adj_instruct.op_info()->HasAttr("input_scale"));
+          CHECK(one_adj_instruct.op_info()->HasInputScale(out_node_name));
-          instruct.mutable_op_info()->SetAttr(
-              "output_scale",
-              one_adj_instruct.op_info()->GetAttr<float>("input_scale"));
+          instruct.mutable_op_info()->SetOutputScale(
+              out_node_name,
+              one_adj_instruct.op_info()->GetInputScale<float>(out_node_name));
           auto update_desc = *instruct.mutable_op_info();
           instruct.ResetOp(update_desc, graph->valid_places());
......
@@ -457,21 +457,23 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
   std::vector<float> input_data_scales;
   std::vector<float> output_data_scales;
   for (auto &var_node : input_var_nodes) {
+    auto var_node_name = var_node->arg()->name;
     auto any_op_node = var_node->outlinks.front();
     CHECK(any_op_node->IsStmt());
     auto &any_inst = any_op_node->AsStmt();
-    if (any_inst.op_info()->HasAttr("input_scale")) {
+    if (any_inst.op_info()->HasInputScale(var_node_name)) {
       input_data_scales.push_back(
-          any_inst.op_info()->GetAttr<float>("input_scale"));
+          any_inst.op_info()->GetInputScale<float>(var_node_name));
     }
   }
   for (auto &var_node : output_var_nodes) {
+    auto var_node_name = var_node->arg()->name;
     auto any_op_node = var_node->inlinks.front();
     CHECK(any_op_node->IsStmt());
     auto &any_inst = any_op_node->AsStmt();
-    if (any_inst.op_info()->HasAttr("output_scale")) {
+    if (any_inst.op_info()->HasOutputScale(var_node_name)) {
       output_data_scales.push_back(
-          any_inst.op_info()->GetAttr<float>("output_scale"));
+          any_inst.op_info()->GetOutputScale<float>(var_node_name));
     }
   }
   if (input_data_scales.size() > 0) {
......
@@ -107,8 +107,8 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) {
   if (op_type == "subgraph") {
     found = InferScaleFromSubgraph(var_name, op_info, scale, false);
   } else {
-    if (op_info->HasAttr("input_scale")) {
-      *scale = op_info->GetAttr<float>("input_scale");
+    if (op_info->HasInputScale(var_name)) {
+      *scale = op_info->GetInputScale<float>(var_name);
       found = true;
     } else {
       // Obtain the output_scale from one of its previous Ops
@@ -120,8 +120,8 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) {
       if (prev_op_type == "subgraph") {
        found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true);
       } else {
-        if (prev_op_info->HasAttr("output_scale")) {
-          *scale = prev_op_info->GetAttr<float>("output_scale");
+        if (prev_op_info->HasOutputScale(var_name)) {
+          *scale = prev_op_info->GetOutputScale<float>(var_name);
           found = true;
         }
       }
......
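`InferScale` now resolves a tensor's scale in two steps: ask the consumer op for an input scale under the tensor's name, and only then fall back to the producer's output scale. A simplified standalone sketch of that precedence (plain maps keyed by variable name; the real code first resolves the name to `<argname><index>_scale`):

```cpp
#include <map>
#include <string>

// Stand-in: maps from variable name to its recorded scale.
bool InferScaleSketch(
    const std::map<std::string, float>& consumer_input_scales,
    const std::map<std::string, float>& producer_output_scales,
    const std::string& var_name, float* scale) {
  auto it = consumer_input_scales.find(var_name);
  if (it != consumer_input_scales.end()) {  // 1) the consumer knows the scale
    *scale = it->second;
    return true;
  }
  it = producer_output_scales.find(var_name);
  if (it != producer_output_scales.end()) {  // 2) fall back to the producer
    *scale = it->second;
    return true;
  }
  return false;  // not a quantized tensor
}

int main() { return 0; }
```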
@@ -22,6 +22,14 @@
 namespace paddle {
 namespace lite {
 
+std::string int2string(int index) {
+  const int BUFFER_LENGTH = 30;
+  char buffer[BUFFER_LENGTH];
+  int num = snprintf(buffer, sizeof(buffer), "%d", index);
+  CHECK(num > 0 && num < sizeof(buffer));
+  return std::string(buffer);
+}
+
 bool OpLite::InferShape() {
   // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
   // InferShapeByMemoryInternal will be applied.
@@ -186,5 +194,115 @@ void OpLite::AttachOutput(const cpp::OpDesc &op_desc,
   }
 }
 
+bool OpInfo::GetInputArgname(const std::string &value_name,
+                             std::string *out) const {
+  for (auto &item : inputs_) {
+    auto it = std::find(item.second.begin(), item.second.end(), value_name);
+    if (it != item.second.end()) {
+      *out = item.first;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool OpInfo::GetOutputArgname(const std::string &value_name,
+                              std::string *out) const {
+  for (auto &item : outputs_) {
+    auto it = std::find(item.second.begin(), item.second.end(), value_name);
+    if (it != item.second.end()) {
+      *out = item.first;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const {
+  for (auto &item : inputs_) {
+    auto it = std::find(item.second.begin(), item.second.end(), input_name);
+    if (it != item.second.end()) {
+      *out = it - item.second.begin();
+      return true;
+    }
+  }
+  return false;
+}
+
+bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const {
+  for (auto &item : outputs_) {
+    auto it = std::find(item.second.begin(), item.second.end(), output_name);
+    if (it != item.second.end()) {
+      *out = it - item.second.begin();
+      return true;
+    }
+  }
+  return false;
+}
+
+bool OpInfo::HasInputScale(const std::string &input_name) const {
+  std::string argname;
+  int index;
+  if (GetInputArgname(input_name, &argname) &&
+      GetInputIndex(input_name, &index)) {
+    return HasAttr(argname + int2string(index) + "_scale");
+  } else {
+    return false;
+  }
+}
+
+bool OpInfo::HasOutputScale(const std::string &output_name) const {
+  std::string argname;
+  int index;
+  if (GetOutputArgname(output_name, &argname) &&
+      GetOutputIndex(output_name, &index)) {
+    return HasAttr(argname + int2string(index) + "_scale");
+  } else {
+    return false;
+  }
+}
+
+template <>
+void OpInfo::SetInputScale(const std::string &input_name,
+                           const float &scale_value) {
+  std::string argname;
+  int index;
+  CHECK(GetInputArgname(input_name, &argname));
+  CHECK(GetInputIndex(input_name, &index));
+  SetAttr<float>(argname + int2string(index) + "_scale", scale_value);
+}
+
+template <>
+void OpInfo::SetInputScale(const std::string &input_name,
+                           const std::vector<float> &scale_value) {
+  std::string argname;
+  int index;
+  CHECK(GetInputArgname(input_name, &argname));
+  CHECK(GetInputIndex(input_name, &index));
+  SetAttr<std::vector<float>>(argname + int2string(index) + "_scale",
+                              scale_value);
+}
+
+template <>
+void OpInfo::SetOutputScale(const std::string &output_name,
+                            const float &scale_value) {
+  std::string argname;
+  int index;
+  CHECK(GetOutputArgname(output_name, &argname));
+  CHECK(GetOutputIndex(output_name, &index));
+  SetAttr<float>(argname + int2string(index) + "_scale", scale_value);
+}
+
+template <>
+void OpInfo::SetOutputScale(const std::string &output_name,
+                            const std::vector<float> &scale_value) {
+  std::string argname;
+  int index;
+  CHECK(GetOutputArgname(output_name, &argname));
+  CHECK(GetOutputIndex(output_name, &index));
+  SetAttr<std::vector<float>>(argname + int2string(index) + "_scale",
+                              scale_value);
+}
+
 }  // namespace lite
 }  // namespace paddle
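Taken together, the new `OpInfo` methods form a small API. A usage sketch (compiles only against PaddleLite's `lite/core/op_lite.h`; the variable names `conv_in`, `conv_filter`, and `conv_out` are hypothetical):

```cpp
#include <vector>
#include "lite/core/op_lite.h"

void TagScales(paddle::lite::OpInfo* op_info) {
  // Scalar activation scale, stored as "<argname><index>_scale",
  // e.g. "Input0_scale" if "conv_in" is the first "Input" variable.
  op_info->SetInputScale("conv_in", 0.017f);
  if (op_info->HasInputScale("conv_in")) {
    float s = op_info->GetInputScale<float>("conv_in");
    (void)s;
  }
  // Per-channel weight scales use the std::vector<float> specialization.
  op_info->SetInputScale("conv_filter", std::vector<float>{0.1f, 0.2f});
  op_info->SetOutputScale("conv_out", 0.05f);
}
```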
@@ -30,6 +30,8 @@
 namespace paddle {
 namespace lite {
 
+std::string int2string(int index);
+
 // For registry factory.
 struct Registry {
   void Touch() {}
@@ -229,51 +231,36 @@ class OpInfo : public cpp::OpDesc {
     return OutputArgumentNames();
   }
 
-  bool GetInputArgname(const std::string &value_name, std::string *out) const {
-    for (auto &item : inputs_) {
-      auto it = std::find(item.second.begin(), item.second.end(), value_name);
-      if (it != item.second.end()) {
-        *out = item.first;
-        return true;
-      }
-    }
-    return false;
-  }
-  bool GetOutputArgname(const std::string &value_name, std::string *out) const {
-    for (auto &item : outputs_) {
-      auto it = std::find(item.second.begin(), item.second.end(), value_name);
-      if (it != item.second.end()) {
-        *out = item.first;
-        return true;
-      }
-    }
-    return false;
-  }
+  bool GetInputArgname(const std::string &value_name, std::string *out) const;
+  bool GetOutputArgname(const std::string &value_name, std::string *out) const;
 
   // For the input variable name, find the index of the corresponding
   // input argname
-  bool GetInputIndex(const std::string &value_name, int *out) const {
-    for (auto &item : inputs_) {
-      auto it = std::find(item.second.begin(), item.second.end(), value_name);
-      if (it != item.second.end()) {
-        *out = it - item.second.begin();
-        return true;
-      }
-    }
-    return false;
-  }
-  // For the output variable name, find the index of the corresponding
-  // output argname
-  bool GetOutputIndex(const std::string &value_name, int *out) const {
-    for (auto &item : outputs_) {
-      auto it = std::find(item.second.begin(), item.second.end(), value_name);
-      if (it != item.second.end()) {
-        *out = it - item.second.begin();
-        return true;
-      }
-    }
-    return false;
-  }
+  bool GetInputIndex(const std::string &input_name, int *out) const;
+  bool GetOutputIndex(const std::string &output_name, int *out) const;
+
+  bool HasInputScale(const std::string &input_name) const;
+  bool HasOutputScale(const std::string &output_name) const;
+
+  template <typename T>
+  void SetInputScale(const std::string &input_name, const T &scale_value);
+  template <typename T>
+  void SetOutputScale(const std::string &output_name, const T &scale_value);
+
+  template <typename T>
+  T GetInputScale(const std::string &input_name) const {
+    std::string argname;
+    int index;
+    CHECK(GetInputArgname(input_name, &argname));
+    CHECK(GetInputIndex(input_name, &index));
+    return GetAttr<T>(argname + int2string(index) + "_scale");
+  }
+
+  template <typename T>
+  T GetOutputScale(const std::string &output_name) const {
+    std::string argname;
+    int index;
+    CHECK(GetOutputArgname(output_name, &argname));
+    CHECK(GetOutputIndex(output_name, &index));
+    return GetAttr<T>(argname + int2string(index) + "_scale");
+  }
 
   void UpdateAllInputs(const std::string &from, const std::string &to) {
......
@@ -99,12 +99,16 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   std::vector<float> weight_scale;
   if (op_info->HasAttr("enable_int8")) {
     if (op_info->GetAttr<bool>("enable_int8")) {
-      if (op_info->HasAttr("input_scale"))
-        input_scale = op_info->GetAttr<float>("input_scale");
-      if (op_info->HasAttr("weight_scale"))
-        weight_scale = op_info->GetAttr<std::vector<float>>("weight_scale");
-      if (op_info->HasAttr("output_scale"))
-        output_scale = op_info->GetAttr<float>("output_scale");
+      auto input_name = op_info->Input("Input").front();
+      auto filter_name = op_info->Input("Filter").front();
+      auto output_name = op_info->Output("Output").front();
+      if (op_info->HasInputScale(input_name))
+        input_scale = op_info->GetInputScale<float>(input_name);
+      if (op_info->HasInputScale(filter_name))
+        weight_scale = op_info->GetInputScale<std::vector<float>>(filter_name);
+      if (op_info->HasOutputScale(output_name)) {
+        output_scale = op_info->GetOutputScale<float>(output_name);
+      }
       VLOG(3) << "has output scale:" << output_scale;
     } else {
       return FAILED;
......
@@ -57,12 +57,15 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   std::vector<float> w_scale;
   if (op_info->HasAttr("enable_int8")) {
     if (op_info->GetAttr<bool>("enable_int8")) {
-      if (op_info->HasAttr("input_scale"))
-        input_scale = op_info->GetAttr<float>("input_scale");
-      if (op_info->HasAttr("weight_scale"))
-        w_scale = op_info->GetAttr<std::vector<float>>("weight_scale");
-      if (op_info->HasAttr("output_scale"))
-        out_scale = op_info->GetAttr<float>("output_scale");
+      auto input_name = op_info->Input("Input").front();
+      auto weight_name = op_info->Input("W").front();
+      auto out_name = op_info->Output("Out").front();
+      if (op_info->HasInputScale(input_name))
+        input_scale = op_info->GetInputScale<float>(input_name);
+      if (op_info->HasInputScale(weight_name))
+        w_scale = op_info->GetInputScale<std::vector<float>>(weight_name);
+      if (op_info->HasOutputScale(out_name))
+        out_scale = op_info->GetOutputScale<float>(out_name);
     } else {
       return FAILED;
     }
......
@@ -91,10 +91,12 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   float out_scale = 1.0f;
   if (op_info->HasAttr("enable_int8")) {
     if (op_info->GetAttr<bool>("enable_int8")) {
-      if (op_info->HasAttr("input_scale"))
-        x_scale = op_info->GetAttr<float>("input_scale");
-      if (op_info->HasAttr("output_scale"))
-        out_scale = op_info->GetAttr<float>("output_scale");
+      auto x_name = op_info->Input("X").front();
+      auto out_name = op_info->Output("Out").front();
+      if (op_info->HasInputScale(x_name))
+        x_scale = op_info->GetInputScale<float>(x_name);
+      if (op_info->HasOutputScale(out_name))
+        out_scale = op_info->GetOutputScale<float>(out_name);
     } else {
       LOG(WARNING) << "Do not enable_int8";
       return FAILED;
......
@@ -49,10 +49,12 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   float out_scale = 1.0f;
   if (op_info->HasAttr("enable_int8")) {
     if (op_info->GetAttr<bool>("enable_int8")) {
-      if (op_info->HasAttr("input_scale"))
-        input_scale = op_info->GetAttr<float>("input_scale");
-      if (op_info->HasAttr("output_scale"))
-        out_scale = op_info->GetAttr<float>("output_scale");
+      auto x_name = op_info->Input("X").front();
+      auto out_name = op_info->Output("Out").front();
+      if (op_info->HasInputScale(x_name))
+        input_scale = op_info->GetInputScale<float>(x_name);
+      if (op_info->HasOutputScale(out_name))
+        out_scale = op_info->GetOutputScale<float>(out_name);
     } else {
       LOG(WARNING) << "Do not enable_int8";
       return FAILED;
......
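The four converter hunks above (conv, fc, pool, softmax) repeat one lookup pattern; something like the helper below could factor it out for the single-input, single-output cases (hypothetical refactor, not part of this patch):

```cpp
#include <string>
#include "lite/core/op_lite.h"

// Fills the scalar input scale and output scale for a quantized op with one
// input and one output argument; returns false if int8 is not enabled.
bool GetIoScales(const paddle::lite::OpInfo* op_info,
                 const std::string& in_arg, const std::string& out_arg,
                 float* in_scale, float* out_scale) {
  if (!op_info->HasAttr("enable_int8") ||
      !op_info->GetAttr<bool>("enable_int8")) {
    return false;
  }
  auto in_name = op_info->Input(in_arg).front();
  auto out_name = op_info->Output(out_arg).front();
  if (op_info->HasInputScale(in_name))
    *in_scale = op_info->GetInputScale<float>(in_name);
  if (op_info->HasOutputScale(out_name))
    *out_scale = op_info->GetOutputScale<float>(out_name);
  return true;
}
```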
@@ -130,15 +130,19 @@ class ConvOpLite : public OpLite {
       padding_algorithm_ = op_desc.GetAttr<std::string>("padding_algorithm");
     }
     // For Int8
-    if (op_desc.HasAttr("enable_int8")) {
-      param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
-      if (op_desc.HasAttr("input_scale"))
-        param_.input_scale = op_desc.GetAttr<float>("input_scale");
-      if (op_desc.HasAttr("weight_scale"))
-        param_.weight_scale =
-            op_desc.GetAttr<std::vector<float>>("weight_scale");
-      if (op_desc.HasAttr("output_scale")) {
-        param_.output_scale = op_desc.GetAttr<float>("output_scale");
+    const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc);
+    if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
+      param_.enable_int8 = op_info->GetAttr<bool>("enable_int8");
+      auto input_name = op_info->Input("Input").front();
+      auto filter_name = op_info->Input("Filter").front();
+      auto output_name = op_info->Output("Output").front();
+      if (op_info->HasInputScale(input_name))
+        param_.input_scale = op_info->GetInputScale<float>(input_name);
+      if (op_info->HasInputScale(filter_name))
+        param_.weight_scale =
+            op_info->GetInputScale<std::vector<float>>(filter_name);
+      if (op_info->HasOutputScale(output_name)) {
+        param_.output_scale = op_info->GetOutputScale<float>(output_name);
       }
     }
......
@@ -102,14 +102,19 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   }
 
   // For Int8
-  if (op_desc.HasAttr("enable_int8")) {
-    param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
-    if (op_desc.HasAttr("input_scale"))
-      param_.input_scale = op_desc.GetAttr<float>("input_scale");
-    if (op_desc.HasAttr("weight_scale"))
-      param_.weight_scale = op_desc.GetAttr<std::vector<float>>("weight_scale");
-    if (op_desc.HasAttr("output_scale"))
-      param_.output_scale = op_desc.GetAttr<float>("output_scale");
+  const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc);
+  if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
+    param_.enable_int8 = op_info->GetAttr<bool>("enable_int8");
+    auto input_name = op_info->Input("Input").front();
+    auto weight_name = op_info->Input("W").front();
+    auto out_name = op_info->Output("Out").front();
+    if (op_info->HasInputScale(input_name))
+      param_.input_scale = op_info->GetInputScale<float>(input_name);
+    if (op_info->HasInputScale(weight_name))
+      param_.weight_scale =
+          op_info->GetInputScale<std::vector<float>>(weight_name);
+    if (op_info->HasOutputScale(out_name))
+      param_.output_scale = op_info->GetOutputScale<float>(out_name);
   }
 
   return true;
 }
......
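In the two `AttachImpl` hunks above, the hook still receives a `cpp::OpDesc`, while the scale helpers live on the `OpInfo` subclass, so a checked `dynamic_cast` gates the int8 path. A minimal sketch of that guard (assumes PaddleLite headers; `ReadScales` is an illustrative name):

```cpp
#include "lite/core/op_lite.h"

void ReadScales(const paddle::lite::cpp::OpDesc& op_desc) {
  // The cast yields nullptr if the descriptor is a plain cpp::OpDesc,
  // in which case no per-input scales can be attached to the param.
  const auto* op_info = dynamic_cast<const paddle::lite::OpInfo*>(&op_desc);
  if (op_info == nullptr || !op_info->HasAttr("enable_int8")) return;
  auto input_name = op_info->Input("Input").front();
  if (op_info->HasInputScale(input_name)) {
    float input_scale = op_info->GetInputScale<float>(input_name);
    (void)input_scale;  // a real op would store this into its param struct
  }
}
```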