未验证 提交 eca4dcbb 编写于 作者: C cc 提交者: GitHub

Optimize obtaining quantized scale in inference stage, test=develop (#4308)

* Optimize obtaining quantized scale in inference, test=develop
上级 e45c4242
......@@ -233,67 +233,97 @@ bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const {
return false;
}
bool OpInfo::HasInputScale(const std::string &input_name) const {
std::string argname;
int index;
if (GetInputArgname(input_name, &argname) &&
GetInputIndex(input_name, &index)) {
return HasAttr(argname + to_string(index) + "_scale");
// Reports whether this op desc holds a quantized-scale attribute for an input.
//
// name          - the scale attribute key itself when is_scale_name is true,
//                 otherwise an input variable name.
// is_scale_name - selects how `name` is interpreted.
//
// Returns true only when the resolved attribute exists. An input variable name
// that cannot be mapped to an argname and index yields false.
bool OpInfo::HasInputScale(const std::string &name, bool is_scale_name) const {
  if (is_scale_name) {
    return HasAttr(name);
  }
  // Map the input variable name to its "<argname><index>_scale" key. This
  // lookup must actually run: returning false before it would make every
  // query by input variable name report that no scale is present.
  std::string argname;
  int index = 0;
  if (GetInputArgname(name, &argname) && GetInputIndex(name, &index)) {
    return HasAttr(argname + to_string(index) + "_scale");
  }
  return false;
}
bool OpInfo::HasOutputScale(const std::string &output_name) const {
std::string argname;
int index;
if (GetOutputArgname(output_name, &argname) &&
GetOutputIndex(output_name, &index)) {
return HasAttr(argname + to_string(index) + "_scale");
// Reports whether this op desc holds a quantized-scale attribute for an
// output.
//
// name          - the scale attribute key itself when is_scale_name is true,
//                 otherwise an output variable name.
// is_scale_name - selects how `name` is interpreted.
//
// Returns true only when the resolved attribute exists. An output variable
// name that cannot be mapped to an argname and index yields false.
bool OpInfo::HasOutputScale(const std::string &name, bool is_scale_name) const {
  if (is_scale_name) {
    return HasAttr(name);
  }
  // Map the output variable name to its "<argname><index>_scale" key. This
  // lookup must actually run: returning false before it would make every
  // query by output variable name report that no scale is present.
  std::string argname;
  int index = 0;
  if (GetOutputArgname(name, &argname) && GetOutputIndex(name, &index)) {
    return HasAttr(argname + to_string(index) + "_scale");
  }
  return false;
}
void OpInfo::SetInputScale(const std::string &input_name,
const std::vector<float> &scale_value) {
std::string argname;
int index;
CHECK(GetInputArgname(input_name, &argname));
CHECK(GetInputIndex(input_name, &index));
CHECK(scale_value.size() > 0)
<< "Error in SetInputScale: the scales should not be empty";
SetAttr<std::vector<float>>(argname + to_string(index) + "_scale",
scale_value);
// Stores `scale_value` in this op desc as the quantized scale of an input.
//
// When is_scale_name is true, `name` is used verbatim as the attribute key.
// Otherwise `name` is an input variable name: it must resolve to an argname
// and an index, `scale_value` must be non-empty, and the attribute key is
// built as "<argname><index>_scale".
void OpInfo::SetInputScale(const std::string &name,
                           const std::vector<float> &scale_value,
                           bool is_scale_name) {
  if (is_scale_name) {
    // The caller already supplied the attribute key.
    SetAttr<std::vector<float>>(name, scale_value);
    return;
  }

  std::string argname;
  int index;
  CHECK(GetInputArgname(name, &argname));
  CHECK(GetInputIndex(name, &index));
  CHECK(scale_value.size() > 0)
      << "Error in SetInputScale: the scales should not be empty";
  SetAttr<std::vector<float>>(argname + to_string(index) + "_scale",
                              scale_value);
}
void OpInfo::SetOutputScale(const std::string &output_name,
const std::vector<float> &scale_value) {
std::string argname;
int index;
CHECK(GetOutputArgname(output_name, &argname));
CHECK(GetOutputIndex(output_name, &index));
CHECK(scale_value.size() > 0)
<< "Error in SetOutputScale: the scales should not be empty";
SetAttr<std::vector<float>>(argname + to_string(index) + "_scale",
scale_value);
// Stores `scale_value` in this op desc as the quantized scale of an output.
//
// When is_scale_name is true, `name` is used verbatim as the attribute key.
// Otherwise `name` is an output variable name: it must resolve to an argname
// and an index, `scale_value` must be non-empty, and the attribute key is
// built as "<argname><index>_scale".
void OpInfo::SetOutputScale(const std::string &name,
                            const std::vector<float> &scale_value,
                            bool is_scale_name) {
  if (is_scale_name) {
    // The caller already supplied the attribute key.
    SetAttr<std::vector<float>>(name, scale_value);
    return;
  }

  std::string argname;
  int index;
  CHECK(GetOutputArgname(name, &argname));
  CHECK(GetOutputIndex(name, &index));
  CHECK(scale_value.size() > 0)
      << "Error in SetOutputScale: the scales should not be empty";
  SetAttr<std::vector<float>>(argname + to_string(index) + "_scale",
                              scale_value);
}
std::vector<float> OpInfo::GetInputScale(const std::string &input_name) const {
std::string argname;
int index;
CHECK(GetInputArgname(input_name, &argname));
CHECK(GetInputIndex(input_name, &index));
return GetAttr<std::vector<float>>(argname + to_string(index) + "_scale");
// Returns the quantized scale stored in this op desc for an input.
//
// When is_scale_name is true, `name` is used verbatim as the attribute key.
// Otherwise `name` is an input variable name that must resolve to an argname
// and an index; the attribute key is then "<argname><index>_scale".
std::vector<float> OpInfo::GetInputScale(const std::string &name,
                                         bool is_scale_name) const {
  if (is_scale_name) {
    // The caller already supplied the attribute key.
    return GetAttr<std::vector<float>>(name);
  }

  std::string argname;
  int index;
  CHECK(GetInputArgname(name, &argname));
  CHECK(GetInputIndex(name, &index));
  return GetAttr<std::vector<float>>(argname + to_string(index) + "_scale");
}
std::vector<float> OpInfo::GetOutputScale(
const std::string &output_name) const {
std::string argname;
int index;
CHECK(GetOutputArgname(output_name, &argname));
CHECK(GetOutputIndex(output_name, &index));
return GetAttr<std::vector<float>>(argname + to_string(index) + "_scale");
std::vector<float> OpInfo::GetOutputScale(const std::string &name,
bool is_scale_name) const {
std::string scale_name;
if (is_scale_name) {
scale_name = name;
} else {
std::string argname;
int index;
CHECK(GetOutputArgname(name, &argname));
CHECK(GetOutputIndex(name, &index));
}
return GetAttr<std::vector<float>>(scale_name);
}
} // namespace lite
......
......@@ -251,19 +251,31 @@ class OpInfo : public cpp::OpDesc {
bool GetInputIndex(const std::string &input_name, int *out) const;
bool GetOutputIndex(const std::string &output_name, int *out) const;
bool HasInputScale(const std::string &input_name) const;
bool HasOutputScale(const std::string &output_name) const;
// Suppose a quantized op has two input argnames (X, Y) and one output
// argname (Out). The scales of input argname X are saved in the op desc as
// (X0_scale, scale_value_0), (X1_scale, scale_value_1)...
// The following APIs get or set the quantized scales in the op desc.
// When passing an input or output name, is_scale_name should be false.
// When passing a scale name such as X0_scale, is_scale_name should be
// true.
bool HasInputScale(const std::string &name, bool is_scale_name = false) const;
bool HasOutputScale(const std::string &name,
bool is_scale_name = false) const;
void SetInputScale(const std::string &input_name,
const std::vector<float> &scale_value);
const std::vector<float> &scale_value,
bool is_scale_name = false);
void SetOutputScale(const std::string &output_name,
const std::vector<float> &scale_value);
const std::vector<float> &scale_value,
bool is_scale_name = false);
// For conv2d, depthwise_conv2d and mul, the scale of the weight is a vector.
// Otherwise, all input and output scales are scalars, but we save them
// as vectors.
std::vector<float> GetInputScale(const std::string &input_name) const;
std::vector<float> GetOutputScale(const std::string &output_name) const;
std::vector<float> GetInputScale(const std::string &name,
bool is_scale_name = false) const;
std::vector<float> GetOutputScale(const std::string &name,
bool is_scale_name = false) const;
};
} // namespace lite
......
......@@ -133,15 +133,16 @@ class ConvOpLite : public OpLite {
const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc);
if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
param_.enable_int8 = op_info->GetAttr<bool>("enable_int8");
auto input_name = op_info->Input("Input").front();
auto filter_name = op_info->Input("Filter").front();
auto output_name = op_info->Output("Output").front();
if (op_info->HasInputScale(input_name))
param_.input_scale = op_info->GetInputScale(input_name)[0];
if (op_info->HasInputScale(filter_name))
param_.weight_scale = op_info->GetInputScale(filter_name);
if (op_info->HasOutputScale(output_name)) {
param_.output_scale = op_info->GetOutputScale(output_name)[0];
auto input_scale_name = "Input0_scale";
auto filter_scale_name = "Filter0_scale";
auto output_scale_name = "Output0_scale";
if (op_info->HasInputScale(input_scale_name, true))
param_.input_scale = op_info->GetInputScale(input_scale_name, true)[0];
if (op_info->HasInputScale(filter_scale_name, true))
param_.weight_scale = op_info->GetInputScale(filter_scale_name, true);
if (op_info->HasOutputScale(output_scale_name, true)) {
param_.output_scale =
op_info->GetOutputScale(output_scale_name, true)[0];
}
}
......
......@@ -112,15 +112,15 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
const OpInfo* op_info = dynamic_cast<const OpInfo*>(&op_desc);
if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
param_.enable_int8 = op_info->GetAttr<bool>("enable_int8");
auto input_name = op_info->Input("Input").front();
auto weight_name = op_info->Input("W").front();
auto out_name = op_info->Output("Out").front();
if (op_info->HasInputScale(input_name))
param_.input_scale = op_info->GetInputScale(input_name)[0];
if (op_info->HasInputScale(weight_name))
param_.weight_scale = op_info->GetInputScale(weight_name);
if (op_info->HasOutputScale(out_name))
param_.output_scale = op_info->GetOutputScale(out_name)[0];
auto input_scale_name = "Input0_scale";
auto weight_scale_name = "W0_scale";
auto out_scale_name = "Out0_scale";
if (op_info->HasInputScale(input_scale_name, true))
param_.input_scale = op_info->GetInputScale(input_scale_name, true)[0];
if (op_info->HasInputScale(weight_scale_name, true))
param_.weight_scale = op_info->GetInputScale(weight_scale_name, true);
if (op_info->HasOutputScale(out_scale_name, true))
param_.output_scale = op_info->GetOutputScale(out_scale_name, true)[0];
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册