diff --git a/lite/core/program.cc b/lite/core/program.cc index 9c864ebea50ed07f514a32328a196dc4eedfcf72..9a6790f65430007b490030c338db1232403fda8e 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -73,7 +73,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { std::unordered_map origin_var_maps; auto& main_block = *desc->GetBlock(0); auto var_size = main_block.VarsSize(); - for (size_t i = 0; i < var_size; i++) { + for (int i = 0; i < var_size; i++) { auto v = main_block.GetVar(i); auto name = v->Name(); origin_var_maps.emplace(name, *v); @@ -86,16 +86,12 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { auto* scope = op->scope(); auto in_names = op->op_info()->input_names(); auto out_names = op->op_info()->output_names(); - - std::vector var_names; - var_names.insert(var_names.end(), in_names.begin(), in_names.end()); - var_names.insert(var_names.end(), out_names.begin(), out_names.end()); - std::sort(var_names.begin(), var_names.end()); - var_names.erase(std::unique(var_names.begin(), var_names.end()), - var_names.end()); - - for (auto& var_name : var_names) { - auto it = origin_var_maps.find(var_name); + in_names.insert(in_names.end(), out_names.begin(), out_names.end()); + std::sort(in_names.begin(), in_names.end()); + in_names.erase(std::unique(in_names.begin(), in_names.end()), + in_names.end()); + for (auto& in_name : in_names) { + auto it = origin_var_maps.find(in_name); if (it != origin_var_maps.end()) { auto* v = main_block.AddVar(); v->SetName((it->second).Name()); @@ -108,30 +104,37 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { } else { // New created vars must be LOD_TENSOR auto* v = main_block.AddVar(); - v->SetName(var_name); + v->SetName(in_name); v->SetType(cpp::VarDesc::Type::LOD_TENSOR); std::string in_arg_name; - op->op_info()->GetInputArgname(var_name, &in_arg_name); - auto type = kernel->GetInputDeclType(in_arg_name); + const Type* type; + if (op->op_info()->GetInputArgname(in_name, &in_arg_name)) { + type = kernel->GetInputDeclType(in_arg_name); + } else { + op->op_info()->GetOutputArgname(in_name, &in_arg_name); + type = kernel->GetOutputDeclType(in_arg_name); + } if (type->IsTensor()) { - auto tensor = scope->FindVar(var_name)->GetMutable(); + auto tensor = scope->FindVar(in_name)->GetMutable(); v->SetPersistable(tensor->persistable()); - if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") { + if (in_name != "feed" && in_name != "fetch") { v->SetShape(tensor->dims().data()); switch (tensor->precision()) { -#define SET_DATATYPE(precision__, data_type) \ - case PrecisionType::precision__: \ - v->SetDataType(data_type); \ +#define SET_DATATYPE(precision__, data_type) \ + case PrecisionType::precision__: \ + v->SetDataType(data_type); \ + LOG(INFO) << "update var" << (it->second).Name() << "done"; \ break - + SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL); SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32); + SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16); SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8); SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16); SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32); SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64); #undef SET_DATATYPE default: - LOG(FATAL) << "unknown precision type"; + VLOG(4) << "warning! unknown precision type"; } } } else { @@ -141,7 +144,6 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { } } } - void RuntimeProgram::Run() { #ifdef LITE_WITH_PRECISION_PROFILE auto inst_precision_profiler = paddle::lite::profile::PrecisionProfiler();