diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc index 585aaf3b703bca0a0a34030106dbf793e2a31d52..c3c00d0fa0edeaf3f7070b3e912b9891e44fdc86 100644 --- a/lite/core/op_lite.cc +++ b/lite/core/op_lite.cc @@ -233,67 +233,97 @@ bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const { return false; } -bool OpInfo::HasInputScale(const std::string &input_name) const { - std::string argname; - int index; - if (GetInputArgname(input_name, &argname) && - GetInputIndex(input_name, &index)) { - return HasAttr(argname + to_string(index) + "_scale"); +bool OpInfo::HasInputScale(const std::string &name, bool is_scale_name) const { + bool res = false; + if (is_scale_name) { + res = HasAttr(name); } else { - return false; + std::string argname; + int index; + if (GetInputArgname(name, &argname) && GetInputIndex(name, &index)) { + res = HasAttr(argname + to_string(index) + "_scale"); + } } + return res; } -bool OpInfo::HasOutputScale(const std::string &output_name) const { - std::string argname; - int index; - if (GetOutputArgname(output_name, &argname) && - GetOutputIndex(output_name, &index)) { - return HasAttr(argname + to_string(index) + "_scale"); +bool OpInfo::HasOutputScale(const std::string &name, bool is_scale_name) const { + bool res = false; + if (is_scale_name) { + res = HasAttr(name); } else { - return false; + std::string argname; + int index; + if (GetOutputArgname(name, &argname) && GetOutputIndex(name, &index)) { + res = HasAttr(argname + to_string(index) + "_scale"); + } } + return res; } -void OpInfo::SetInputScale(const std::string &input_name, - const std::vector &scale_value) { - std::string argname; - int index; - CHECK(GetInputArgname(input_name, &argname)); - CHECK(GetInputIndex(input_name, &index)); - CHECK(scale_value.size() > 0) - << "Error in SetInputScale: the scales should not be empty"; - SetAttr>(argname + to_string(index) + "_scale", - scale_value); +void OpInfo::SetInputScale(const std::string &name, + const std::vector &scale_value, + bool is_scale_name) { + std::string scale_name; + if (is_scale_name) { + scale_name = name; + } else { + std::string argname; + int index; + CHECK(GetInputArgname(name, &argname)); + CHECK(GetInputIndex(name, &index)); + CHECK(scale_value.size() > 0) + << "Error in SetInputScale: the scales should not be empty"; + scale_name = argname + to_string(index) + "_scale"; + } + SetAttr>(scale_name, scale_value); } -void OpInfo::SetOutputScale(const std::string &output_name, - const std::vector &scale_value) { - std::string argname; - int index; - CHECK(GetOutputArgname(output_name, &argname)); - CHECK(GetOutputIndex(output_name, &index)); - CHECK(scale_value.size() > 0) - << "Error in SetOutputScale: the scales should not be empty"; - SetAttr>(argname + to_string(index) + "_scale", - scale_value); +void OpInfo::SetOutputScale(const std::string &name, + const std::vector &scale_value, + bool is_scale_name) { + std::string scale_name; + if (is_scale_name) { + scale_name = name; + } else { + std::string argname; + int index; + CHECK(GetOutputArgname(name, &argname)); + CHECK(GetOutputIndex(name, &index)); + CHECK(scale_value.size() > 0) + << "Error in SetOutputScale: the scales should not be empty"; + scale_name = argname + to_string(index) + "_scale"; + } + SetAttr>(scale_name, scale_value); } -std::vector OpInfo::GetInputScale(const std::string &input_name) const { - std::string argname; - int index; - CHECK(GetInputArgname(input_name, &argname)); - CHECK(GetInputIndex(input_name, &index)); - return GetAttr>(argname + to_string(index) + "_scale"); +std::vector OpInfo::GetInputScale(const std::string &name, + bool is_scale_name) const { + std::string scale_name; + if (is_scale_name) { + scale_name = name; + } else { + std::string argname; + int index; + CHECK(GetInputArgname(name, &argname)); + CHECK(GetInputIndex(name, &index)); + scale_name = argname + to_string(index) + "_scale"; + } + return GetAttr>(scale_name); } -std::vector OpInfo::GetOutputScale( - const std::string &output_name) const { - std::string argname; - int index; - CHECK(GetOutputArgname(output_name, &argname)); - CHECK(GetOutputIndex(output_name, &index)); - return GetAttr>(argname + to_string(index) + "_scale"); +std::vector OpInfo::GetOutputScale(const std::string &name, + bool is_scale_name) const { + std::string scale_name; + if (is_scale_name) { + scale_name = name; + } else { + std::string argname; + int index; + CHECK(GetOutputArgname(name, &argname)); + CHECK(GetOutputIndex(name, &index)); + } + return GetAttr>(scale_name); } } // namespace lite diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h index d94753220a1b5d963092c62c43d7e49b03243c63..1e664152a39110bdfc28cbb037920b6174315aa5 100644 --- a/lite/core/op_lite.h +++ b/lite/core/op_lite.h @@ -251,19 +251,31 @@ class OpInfo : public cpp::OpDesc { bool GetInputIndex(const std::string &input_name, int *out) const; bool GetOutputIndex(const std::string &output_name, int *out) const; - bool HasInputScale(const std::string &input_name) const; - bool HasOutputScale(const std::string &output_name) const; + // If a quantized op has two input argname (X, Y) and one output + // argname (Out). The scales of input argname X are saved in op desc as + // (X0_scale, scale_value_0), (X1_scale, scale_value_1)... + // The following APIs get or set the quantized scale in op_desc. + // If use the input or output name, the is_scale_name should be false. + // If use the scale_name such as (X0_scale, scale_value_0), + // the is_scale_name should be true. + bool HasInputScale(const std::string &name, bool is_scale_name = false) const; + bool HasOutputScale(const std::string &name, + bool is_scale_name = false) const; void SetInputScale(const std::string &input_name, - const std::vector &scale_value); + const std::vector &scale_value, + bool is_scale_name = false); void SetOutputScale(const std::string &output_name, - const std::vector &scale_value); + const std::vector &scale_value, + bool is_scale_name = false); // For conv2d, depthwise_conv2d and mul, the scale of weight are a vector. // Otherwise, all input and output scales are scalar, but we save these // as vecotr. - std::vector GetInputScale(const std::string &input_name) const; - std::vector GetOutputScale(const std::string &output_name) const; + std::vector GetInputScale(const std::string &name, + bool is_scale_name = false) const; + std::vector GetOutputScale(const std::string &name, + bool is_scale_name = false) const; }; } // namespace lite diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index a1d4e2e8a038046b257b3ab5f936cc4cb2e62c67..38ef1c6878db570d401bafdb0656a368d377eb46 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -133,15 +133,16 @@ class ConvOpLite : public OpLite { const OpInfo* op_info = dynamic_cast(&op_desc); if (op_info != nullptr && op_info->HasAttr("enable_int8")) { param_.enable_int8 = op_info->GetAttr("enable_int8"); - auto input_name = op_info->Input("Input").front(); - auto filter_name = op_info->Input("Filter").front(); - auto output_name = op_info->Output("Output").front(); - if (op_info->HasInputScale(input_name)) - param_.input_scale = op_info->GetInputScale(input_name)[0]; - if (op_info->HasInputScale(filter_name)) - param_.weight_scale = op_info->GetInputScale(filter_name); - if (op_info->HasOutputScale(output_name)) { - param_.output_scale = op_info->GetOutputScale(output_name)[0]; + auto input_scale_name = "Input0_scale"; + auto filter_scale_name = "Filter0_scale"; + auto output_scale_name = "Output0_scale"; + if (op_info->HasInputScale(input_scale_name, true)) + param_.input_scale = op_info->GetInputScale(input_scale_name, true)[0]; + if (op_info->HasInputScale(filter_scale_name, true)) + param_.weight_scale = op_info->GetInputScale(filter_scale_name, true); + if (op_info->HasOutputScale(output_scale_name, true)) { + param_.output_scale = + op_info->GetOutputScale(output_scale_name, true)[0]; } } diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index 5d60af4af075ac11b936868ed822a28e55baef6b..e776f747fc1278f0f2cb75d6e379843b78c7e3fc 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -112,15 +112,15 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { const OpInfo* op_info = dynamic_cast(&op_desc); if (op_info != nullptr && op_info->HasAttr("enable_int8")) { param_.enable_int8 = op_info->GetAttr("enable_int8"); - auto input_name = op_info->Input("Input").front(); - auto weight_name = op_info->Input("W").front(); - auto out_name = op_info->Output("Out").front(); - if (op_info->HasInputScale(input_name)) - param_.input_scale = op_info->GetInputScale(input_name)[0]; - if (op_info->HasInputScale(weight_name)) - param_.weight_scale = op_info->GetInputScale(weight_name); - if (op_info->HasOutputScale(out_name)) - param_.output_scale = op_info->GetOutputScale(out_name)[0]; + auto input_scale_name = "Input0_scale"; + auto weight_scale_name = "W0_scale"; + auto out_scale_name = "Out0_scale"; + if (op_info->HasInputScale(input_scale_name, true)) + param_.input_scale = op_info->GetInputScale(input_scale_name, true)[0]; + if (op_info->HasInputScale(weight_scale_name, true)) + param_.weight_scale = op_info->GetInputScale(weight_scale_name, true); + if (op_info->HasOutputScale(out_scale_name, true)) + param_.output_scale = op_info->GetOutputScale(out_scale_name, true)[0]; } return true; }