From f779a5b9e4467a36afe5aaf0c9219b75f3c86d51 Mon Sep 17 00:00:00 2001
From: huzhiqiang <912790387@qq.com>
Date: Fri, 22 May 2020 20:07:12 +0800
Subject: [PATCH] [cherry-pick][BUG FIX] fix the issue that opt can not convert quantized model (#3683)

---
 lite/core/program.cc | 46 +++++++++++++++++++++++---------------------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/lite/core/program.cc b/lite/core/program.cc
index 9c864ebea5..9a6790f654 100644
--- a/lite/core/program.cc
+++ b/lite/core/program.cc
@@ -73,7 +73,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
   std::unordered_map<std::string, cpp::VarDesc> origin_var_maps;
   auto& main_block = *desc->GetBlock(0);
   auto var_size = main_block.VarsSize();
-  for (size_t i = 0; i < var_size; i++) {
+  for (int i = 0; i < var_size; i++) {
     auto v = main_block.GetVar(i);
     auto name = v->Name();
     origin_var_maps.emplace(name, *v);
@@ -86,16 +86,12 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
     auto* scope = op->scope();
     auto in_names = op->op_info()->input_names();
     auto out_names = op->op_info()->output_names();
-
-    std::vector<std::string> var_names;
-    var_names.insert(var_names.end(), in_names.begin(), in_names.end());
-    var_names.insert(var_names.end(), out_names.begin(), out_names.end());
-    std::sort(var_names.begin(), var_names.end());
-    var_names.erase(std::unique(var_names.begin(), var_names.end()),
-                    var_names.end());
-
-    for (auto& var_name : var_names) {
-      auto it = origin_var_maps.find(var_name);
+    in_names.insert(in_names.end(), out_names.begin(), out_names.end());
+    std::sort(in_names.begin(), in_names.end());
+    in_names.erase(std::unique(in_names.begin(), in_names.end()),
+                   in_names.end());
+    for (auto& in_name : in_names) {
+      auto it = origin_var_maps.find(in_name);
       if (it != origin_var_maps.end()) {
         auto* v = main_block.AddVar();
         v->SetName((it->second).Name());
@@ -108,30 +104,37 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
       } else {
         // New created vars must be LOD_TENSOR
         auto* v = main_block.AddVar();
-        v->SetName(var_name);
+        v->SetName(in_name);
         v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
         std::string in_arg_name;
-        op->op_info()->GetInputArgname(var_name, &in_arg_name);
-        auto type = kernel->GetInputDeclType(in_arg_name);
+        const Type* type;
+        if (op->op_info()->GetInputArgname(in_name, &in_arg_name)) {
+          type = kernel->GetInputDeclType(in_arg_name);
+        } else {
+          op->op_info()->GetOutputArgname(in_name, &in_arg_name);
+          type = kernel->GetOutputDeclType(in_arg_name);
+        }
         if (type->IsTensor()) {
-          auto tensor = scope->FindVar(var_name)->GetMutable<Tensor>();
+          auto tensor = scope->FindVar(in_name)->GetMutable<Tensor>();
           v->SetPersistable(tensor->persistable());
-          if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") {
+          if (in_name != "feed" && in_name != "fetch") {
             v->SetShape(tensor->dims().data());
             switch (tensor->precision()) {
-#define SET_DATATYPE(precision__, data_type) \
-  case PrecisionType::precision__:           \
-    v->SetDataType(data_type);               \
+#define SET_DATATYPE(precision__, data_type)                      \
+  case PrecisionType::precision__:                                \
+    v->SetDataType(data_type);                                    \
+    LOG(INFO) << "update var" << (it->second).Name() << "done";   \
     break
-
              SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL);
              SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
+             SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16);
              SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
              SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
              SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
              SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
 #undef SET_DATATYPE
              default:
-               LOG(FATAL) << "unknown precision type";
+               VLOG(4) << "warning! unknown precision type";
            }
          }
        } else {
@@ -141,7 +144,6 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
     }
   }
 }
-
 void RuntimeProgram::Run() {
 #ifdef LITE_WITH_PRECISION_PROFILE
   auto inst_precision_profiler = paddle::lite::profile::PrecisionProfiler();
-- 
GitLab
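
For context on what the hunks above change: the fix merges an op's input and output variable names into one de-duplicated list and, when a name cannot be resolved as an input argument, falls back to resolving it as an output argument before asking the kernel for its declared type. The standalone C++ sketch below only illustrates that merge-then-fallback pattern; it uses a hypothetical FakeOpInfo stand-in built from std::map tables rather than Paddle-Lite's real OpInfo/KernelBase classes, so treat it as a minimal illustration under those assumptions, not the library's API.

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for op->op_info(): maps a variable name to the
// argument slot ("Input", "Output", ...) it is bound to on each side.
struct FakeOpInfo {
  std::map<std::string, std::string> input_arg_of;
  std::map<std::string, std::string> output_arg_of;

  bool GetInputArgname(const std::string& var, std::string* arg) const {
    auto it = input_arg_of.find(var);
    if (it == input_arg_of.end()) return false;
    *arg = it->second;
    return true;
  }
  bool GetOutputArgname(const std::string& var, std::string* arg) const {
    auto it = output_arg_of.find(var);
    if (it == output_arg_of.end()) return false;
    *arg = it->second;
    return true;
  }
};

int main() {
  FakeOpInfo op;
  op.input_arg_of = {{"conv_in", "Input"}, {"conv_w", "Filter"}};
  op.output_arg_of = {{"conv_out", "Output"}};

  // Merge inputs and outputs, then sort + unique, as the patched loop does
  // with in_names/out_names.
  std::vector<std::string> in_names = {"conv_in", "conv_w", "conv_in"};
  std::vector<std::string> out_names = {"conv_out"};
  in_names.insert(in_names.end(), out_names.begin(), out_names.end());
  std::sort(in_names.begin(), in_names.end());
  in_names.erase(std::unique(in_names.begin(), in_names.end()), in_names.end());

  for (const auto& name : in_names) {
    std::string arg;
    // Input lookup first, output lookup as the fallback; output-only vars
    // (the case the patch adds) are now resolved instead of being skipped.
    if (op.GetInputArgname(name, &arg)) {
      std::cout << name << " -> input arg " << arg << "\n";
    } else if (op.GetOutputArgname(name, &arg)) {
      std::cout << name << " -> output arg " << arg << "\n";
    }
  }
  return 0;
}

Running the sketch prints each variable name with the argument slot it resolved to, which mirrors the decision the patched loop makes before calling GetInputDeclType or GetOutputDeclType on the kernel.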