提交 6dcbe32b 编写于 作者: H huzhiqiang 提交者: GitHub

[cherry-pick][BUG FIX] fix the issue that opt can not convert quantized model (#3683)

上级 6acf04c6
...@@ -73,7 +73,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { ...@@ -73,7 +73,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
std::unordered_map<std::string, cpp::VarDesc> origin_var_maps; std::unordered_map<std::string, cpp::VarDesc> origin_var_maps;
auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0); auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);
auto var_size = main_block.VarsSize(); auto var_size = main_block.VarsSize();
for (size_t i = 0; i < var_size; i++) { for (int i = 0; i < var_size; i++) {
auto v = main_block.GetVar<cpp::VarDesc>(i); auto v = main_block.GetVar<cpp::VarDesc>(i);
auto name = v->Name(); auto name = v->Name();
origin_var_maps.emplace(name, *v); origin_var_maps.emplace(name, *v);
...@@ -86,16 +86,12 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { ...@@ -86,16 +86,12 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
auto* scope = op->scope(); auto* scope = op->scope();
auto in_names = op->op_info()->input_names(); auto in_names = op->op_info()->input_names();
auto out_names = op->op_info()->output_names(); auto out_names = op->op_info()->output_names();
in_names.insert(in_names.end(), out_names.begin(), out_names.end());
std::vector<std::string> var_names; std::sort(in_names.begin(), in_names.end());
var_names.insert(var_names.end(), in_names.begin(), in_names.end()); in_names.erase(std::unique(in_names.begin(), in_names.end()),
var_names.insert(var_names.end(), out_names.begin(), out_names.end()); in_names.end());
std::sort(var_names.begin(), var_names.end()); for (auto& in_name : in_names) {
var_names.erase(std::unique(var_names.begin(), var_names.end()), auto it = origin_var_maps.find(in_name);
var_names.end());
for (auto& var_name : var_names) {
auto it = origin_var_maps.find(var_name);
if (it != origin_var_maps.end()) { if (it != origin_var_maps.end()) {
auto* v = main_block.AddVar<cpp::VarDesc>(); auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName((it->second).Name()); v->SetName((it->second).Name());
...@@ -108,30 +104,37 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { ...@@ -108,30 +104,37 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
} else { } else {
// New created vars must be LOD_TENSOR // New created vars must be LOD_TENSOR
auto* v = main_block.AddVar<cpp::VarDesc>(); auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName(var_name); v->SetName(in_name);
v->SetType(cpp::VarDesc::Type::LOD_TENSOR); v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
std::string in_arg_name; std::string in_arg_name;
op->op_info()->GetInputArgname(var_name, &in_arg_name); const Type* type;
auto type = kernel->GetInputDeclType(in_arg_name); if (op->op_info()->GetInputArgname(in_name, &in_arg_name)) {
type = kernel->GetInputDeclType(in_arg_name);
} else {
op->op_info()->GetOutputArgname(in_name, &in_arg_name);
type = kernel->GetOutputDeclType(in_arg_name);
}
if (type->IsTensor()) { if (type->IsTensor()) {
auto tensor = scope->FindVar(var_name)->GetMutable<Tensor>(); auto tensor = scope->FindVar(in_name)->GetMutable<Tensor>();
v->SetPersistable(tensor->persistable()); v->SetPersistable(tensor->persistable());
if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") { if (in_name != "feed" && in_name != "fetch") {
v->SetShape(tensor->dims().data()); v->SetShape(tensor->dims().data());
switch (tensor->precision()) { switch (tensor->precision()) {
#define SET_DATATYPE(precision__, data_type) \ #define SET_DATATYPE(precision__, data_type) \
case PrecisionType::precision__: \ case PrecisionType::precision__: \
v->SetDataType(data_type); \ v->SetDataType(data_type); \
LOG(INFO) << "update var" << (it->second).Name() << "done"; \
break break
SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL);
SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32); SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16);
SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8); SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16); SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32); SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64); SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
#undef SET_DATATYPE #undef SET_DATATYPE
default: default:
LOG(FATAL) << "unknown precision type"; VLOG(4) << "warning! unknown precision type";
} }
} }
} else { } else {
...@@ -141,7 +144,6 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { ...@@ -141,7 +144,6 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
} }
} }
} }
void RuntimeProgram::Run() { void RuntimeProgram::Run() {
#ifdef LITE_WITH_PRECISION_PROFILE #ifdef LITE_WITH_PRECISION_PROFILE
auto inst_precision_profiler = paddle::lite::profile::PrecisionProfiler(); auto inst_precision_profiler = paddle::lite::profile::PrecisionProfiler();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册