未验证 提交 eca4dcbb 编写于 作者: C cc 提交者: GitHub

Optimize obtaining quantized scale in inference stage, test=develop (#4308)

* Optimize obtaining quantized scale in inference, test=develop
上级 e45c4242
...@@ -233,67 +233,97 @@ bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const { ...@@ -233,67 +233,97 @@ bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const {
return false; return false;
} }
// Side-by-side diff: the left column is the previous version (input name
// only); the right column adds the `is_scale_name` flag.
// Reports whether a scale attribute exists for an input.
//   name          - a scale attribute name when is_scale_name is true,
//                   otherwise an input variable name.
//   is_scale_name - selects how `name` is interpreted.
// When is_scale_name is true, `name` is passed to HasAttr unchanged.
// Otherwise the attribute name is built as <argname><index>"_scale" from
// GetInputArgname/GetInputIndex; if either lookup fails the result is false.
bool OpInfo::HasInputScale(const std::string &input_name) const { bool OpInfo::HasInputScale(const std::string &name, bool is_scale_name) const {
std::string argname; bool res = false;
int index; if (is_scale_name) {
if (GetInputArgname(input_name, &argname) && res = HasAttr(name);
GetInputIndex(input_name, &index)) {
return HasAttr(argname + to_string(index) + "_scale");
} else { } else {
return false; std::string argname;
int index;
if (GetInputArgname(name, &argname) && GetInputIndex(name, &index)) {
res = HasAttr(argname + to_string(index) + "_scale");
}
} }
return res;
} }
// Side-by-side diff: the left column is the previous version (output name
// only); the right column adds the `is_scale_name` flag.
// Reports whether a scale attribute exists for an output.
//   name          - a scale attribute name when is_scale_name is true,
//                   otherwise an output variable name.
//   is_scale_name - selects how `name` is interpreted.
// When is_scale_name is true, `name` is passed to HasAttr unchanged.
// Otherwise the attribute name is built as <argname><index>"_scale" from
// GetOutputArgname/GetOutputIndex; if either lookup fails the result is false.
bool OpInfo::HasOutputScale(const std::string &output_name) const { bool OpInfo::HasOutputScale(const std::string &name, bool is_scale_name) const {
std::string argname; bool res = false;
int index; if (is_scale_name) {
if (GetOutputArgname(output_name, &argname) && res = HasAttr(name);
GetOutputIndex(output_name, &index)) {
return HasAttr(argname + to_string(index) + "_scale");
} else { } else {
return false; std::string argname;
int index;
if (GetOutputArgname(name, &argname) && GetOutputIndex(name, &index)) {
res = HasAttr(argname + to_string(index) + "_scale");
}
} }
return res;
} }
// Side-by-side diff: the left column is the previous version; the right
// column adds the `is_scale_name` flag.
// Stores `scale_value` as a std::vector<float> attribute for an input.
//   name          - a scale attribute name when is_scale_name is true,
//                   otherwise an input variable name.
//   scale_value   - the scales to store.
//   is_scale_name - selects how `name` is interpreted.
// When is_scale_name is true, `name` is used directly as the attribute name.
// Otherwise the argname and index lookups are CHECKed and the attribute name
// is built as <argname><index>"_scale".
// NOTE(review): in the right-hand column the non-empty CHECK on scale_value
// runs only in the else branch, so an empty vector is accepted when
// is_scale_name is true. Confirm that this is intended.
void OpInfo::SetInputScale(const std::string &input_name, void OpInfo::SetInputScale(const std::string &name,
const std::vector<float> &scale_value) { const std::vector<float> &scale_value,
std::string argname; bool is_scale_name) {
int index; std::string scale_name;
CHECK(GetInputArgname(input_name, &argname)); if (is_scale_name) {
CHECK(GetInputIndex(input_name, &index)); scale_name = name;
CHECK(scale_value.size() > 0) } else {
<< "Error in SetInputScale: the scales should not be empty"; std::string argname;
SetAttr<std::vector<float>>(argname + to_string(index) + "_scale", int index;
scale_value); CHECK(GetInputArgname(name, &argname));
CHECK(GetInputIndex(name, &index));
CHECK(scale_value.size() > 0)
<< "Error in SetInputScale: the scales should not be empty";
scale_name = argname + to_string(index) + "_scale";
}
SetAttr<std::vector<float>>(scale_name, scale_value);
} }
// Side-by-side diff: the left column is the previous version; the right
// column adds the `is_scale_name` flag.
// Stores `scale_value` as a std::vector<float> attribute for an output.
//   name          - a scale attribute name when is_scale_name is true,
//                   otherwise an output variable name.
//   scale_value   - the scales to store.
//   is_scale_name - selects how `name` is interpreted.
// When is_scale_name is true, `name` is used directly as the attribute name.
// Otherwise the argname and index lookups are CHECKed and the attribute name
// is built as <argname><index>"_scale".
// NOTE(review): in the right-hand column the non-empty CHECK on scale_value
// runs only in the else branch, so an empty vector is accepted when
// is_scale_name is true. Confirm that this is intended.
void OpInfo::SetOutputScale(const std::string &output_name, void OpInfo::SetOutputScale(const std::string &name,
const std::vector<float> &scale_value) { const std::vector<float> &scale_value,
std::string argname; bool is_scale_name) {
int index; std::string scale_name;
CHECK(GetOutputArgname(output_name, &argname)); if (is_scale_name) {
CHECK(GetOutputIndex(output_name, &index)); scale_name = name;
CHECK(scale_value.size() > 0) } else {
<< "Error in SetOutputScale: the scales should not be empty"; std::string argname;
SetAttr<std::vector<float>>(argname + to_string(index) + "_scale", int index;
scale_value); CHECK(GetOutputArgname(name, &argname));
CHECK(GetOutputIndex(name, &index));
CHECK(scale_value.size() > 0)
<< "Error in SetOutputScale: the scales should not be empty";
scale_name = argname + to_string(index) + "_scale";
}
SetAttr<std::vector<float>>(scale_name, scale_value);
} }
// Side-by-side diff: the left column is the previous version; the right
// column adds the `is_scale_name` flag.
// Returns the std::vector<float> scale attribute for an input.
//   name          - a scale attribute name when is_scale_name is true,
//                   otherwise an input variable name.
//   is_scale_name - selects how `name` is interpreted.
// When is_scale_name is true, `name` is used directly as the attribute name.
// Otherwise the argname and index lookups are CHECKed and the attribute name
// is built as <argname><index>"_scale".
// This function does not call HasAttr itself before GetAttr.
std::vector<float> OpInfo::GetInputScale(const std::string &input_name) const { std::vector<float> OpInfo::GetInputScale(const std::string &name,
std::string argname; bool is_scale_name) const {
int index; std::string scale_name;
CHECK(GetInputArgname(input_name, &argname)); if (is_scale_name) {
CHECK(GetInputIndex(input_name, &index)); scale_name = name;
return GetAttr<std::vector<float>>(argname + to_string(index) + "_scale"); } else {
std::string argname;
int index;
CHECK(GetInputArgname(name, &argname));
CHECK(GetInputIndex(name, &index));
scale_name = argname + to_string(index) + "_scale";
}
return GetAttr<std::vector<float>>(scale_name);
} }
// Side-by-side diff: the left column is the previous version; the right
// column adds the `is_scale_name` flag.
// Returns the std::vector<float> scale attribute for an output.
//   name          - a scale attribute name when is_scale_name is true,
//                   otherwise an output variable name.
//   is_scale_name - selects how `name` is interpreted.
// When is_scale_name is true, `name` is used directly as the attribute name.
// Otherwise the argname and index lookups are CHECKed and the attribute name
// is built as <argname><index>"_scale".
// This function does not call HasAttr itself before GetAttr.
std::vector<float> OpInfo::GetOutputScale( std::vector<float> OpInfo::GetOutputScale(const std::string &name,
const std::string &output_name) const { bool is_scale_name) const {
std::string argname; std::string scale_name;
int index; if (is_scale_name) {
CHECK(GetOutputArgname(output_name, &argname)); scale_name = name;
CHECK(GetOutputIndex(output_name, &index)); } else {
return GetAttr<std::vector<float>>(argname + to_string(index) + "_scale"); std::string argname;
int index;
CHECK(GetOutputArgname(name, &argname));
CHECK(GetOutputIndex(name, &index));
// Build the attribute name from the output name; without this assignment
// scale_name stays empty and GetAttr is queried with "".
scale_name = argname + to_string(index) + "_scale";
}
return GetAttr<std::vector<float>>(scale_name);
} }
} // namespace lite } // namespace lite
......
...@@ -251,19 +251,31 @@ class OpInfo : public cpp::OpDesc { ...@@ -251,19 +251,31 @@ class OpInfo : public cpp::OpDesc {
bool GetInputIndex(const std::string &input_name, int *out) const; bool GetInputIndex(const std::string &input_name, int *out) const;
bool GetOutputIndex(const std::string &output_name, int *out) const; bool GetOutputIndex(const std::string &output_name, int *out) const;
bool HasInputScale(const std::string &input_name) const; // Suppose a quantized op has two input argnames (X, Y) and one output
bool HasOutputScale(const std::string &output_name) const; // argname (Out). The scales of input argname X are saved in the op desc as
// (X0_scale, scale_value_0), (X1_scale, scale_value_1)...
// The following APIs get or set the quantized scale in op_desc.
// When passing an input or output name, is_scale_name should be false.
// When passing a scale name such as X0_scale, is_scale_name should be
// true.
bool HasInputScale(const std::string &name, bool is_scale_name = false) const;
bool HasOutputScale(const std::string &name,
bool is_scale_name = false) const;
void SetInputScale(const std::string &input_name, void SetInputScale(const std::string &input_name,
const std::vector<float> &scale_value); const std::vector<float> &scale_value,
bool is_scale_name = false);
void SetOutputScale(const std::string &output_name, void SetOutputScale(const std::string &output_name,
const std::vector<float> &scale_value); const std::vector<float> &scale_value,
bool is_scale_name = false);
// For conv2d, depthwise_conv2d and mul, the weight scale is a vector. // For conv2d, depthwise_conv2d and mul, the weight scale is a vector.
// Otherwise, all input and output scales are scalar, but we save these // Otherwise, all input and output scales are scalar, but we save these
// as vectors. // as vectors.
std::vector<float> GetInputScale(const std::string &input_name) const; std::vector<float> GetInputScale(const std::string &name,
std::vector<float> GetOutputScale(const std::string &output_name) const; bool is_scale_name = false) const;
std::vector<float> GetOutputScale(const std::string &name,
bool is_scale_name = false) const;
}; };
} // namespace lite } // namespace lite
......
...@@ -133,15 +133,16 @@ class ConvOpLite : public OpLite { ...@@ -133,15 +133,16 @@ class ConvOpLite : public OpLite {
const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc); const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc);
if (op_info != nullptr && op_info->HasAttr("enable_int8")) { if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
param_.enable_int8 = op_info->GetAttr<bool>("enable_int8"); param_.enable_int8 = op_info->GetAttr<bool>("enable_int8");
auto input_name = op_info->Input("Input").front(); auto input_scale_name = "Input0_scale";
auto filter_name = op_info->Input("Filter").front(); auto filter_scale_name = "Filter0_scale";
auto output_name = op_info->Output("Output").front(); auto output_scale_name = "Output0_scale";
if (op_info->HasInputScale(input_name)) if (op_info->HasInputScale(input_scale_name, true))
param_.input_scale = op_info->GetInputScale(input_name)[0]; param_.input_scale = op_info->GetInputScale(input_scale_name, true)[0];
if (op_info->HasInputScale(filter_name)) if (op_info->HasInputScale(filter_scale_name, true))
param_.weight_scale = op_info->GetInputScale(filter_name); param_.weight_scale = op_info->GetInputScale(filter_scale_name, true);
if (op_info->HasOutputScale(output_name)) { if (op_info->HasOutputScale(output_scale_name, true)) {
param_.output_scale = op_info->GetOutputScale(output_name)[0]; param_.output_scale =
op_info->GetOutputScale(output_scale_name, true)[0];
} }
} }
......
...@@ -112,15 +112,15 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { ...@@ -112,15 +112,15 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc); const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc);
if (op_info != nullptr && op_info->HasAttr("enable_int8")) { if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
param_.enable_int8 = op_info->GetAttr<bool>("enable_int8"); param_.enable_int8 = op_info->GetAttr<bool>("enable_int8");
auto input_name = op_info->Input("Input").front(); auto input_scale_name = "Input0_scale";
auto weight_name = op_info->Input("W").front(); auto weight_scale_name = "W0_scale";
auto out_name = op_info->Output("Out").front(); auto out_scale_name = "Out0_scale";
if (op_info->HasInputScale(input_name)) if (op_info->HasInputScale(input_scale_name, true))
param_.input_scale = op_info->GetInputScale(input_name)[0]; param_.input_scale = op_info->GetInputScale(input_scale_name, true)[0];
if (op_info->HasInputScale(weight_name)) if (op_info->HasInputScale(weight_scale_name, true))
param_.weight_scale = op_info->GetInputScale(weight_name); param_.weight_scale = op_info->GetInputScale(weight_scale_name, true);
if (op_info->HasOutputScale(out_name)) if (op_info->HasOutputScale(out_scale_name, true))
param_.output_scale = op_info->GetOutputScale(out_name)[0]; param_.output_scale = op_info->GetOutputScale(out_scale_name, true)[0];
} }
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册