未验证 提交 6eaa2d7c 编写于 作者: H huzhiqiang 提交者: GitHub

[BUG FIX] fix the issue that opt can not convert quantized model (#3678)

上级 e728a406
......@@ -73,7 +73,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
std::unordered_map<std::string, cpp::VarDesc> origin_var_maps;
auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);
auto var_size = main_block.VarsSize();
for (size_t i = 0; i < var_size; i++) {
for (int i = 0; i < var_size; i++) {
auto v = main_block.GetVar<cpp::VarDesc>(i);
auto name = v->Name();
origin_var_maps.emplace(name, *v);
......@@ -86,16 +86,12 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
auto* scope = op->scope();
auto in_names = op->op_info()->input_names();
auto out_names = op->op_info()->output_names();
std::vector<std::string> var_names;
var_names.insert(var_names.end(), in_names.begin(), in_names.end());
var_names.insert(var_names.end(), out_names.begin(), out_names.end());
std::sort(var_names.begin(), var_names.end());
var_names.erase(std::unique(var_names.begin(), var_names.end()),
var_names.end());
for (auto& var_name : var_names) {
auto it = origin_var_maps.find(var_name);
in_names.insert(in_names.end(), out_names.begin(), out_names.end());
std::sort(in_names.begin(), in_names.end());
in_names.erase(std::unique(in_names.begin(), in_names.end()),
in_names.end());
for (auto& in_name : in_names) {
auto it = origin_var_maps.find(in_name);
if (it != origin_var_maps.end()) {
auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName((it->second).Name());
......@@ -108,30 +104,37 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
} else {
// New created vars must be LOD_TENSOR
auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName(var_name);
v->SetName(in_name);
v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
std::string in_arg_name;
op->op_info()->GetInputArgname(var_name, &in_arg_name);
auto type = kernel->GetInputDeclType(in_arg_name);
const Type* type;
if (op->op_info()->GetInputArgname(in_name, &in_arg_name)) {
type = kernel->GetInputDeclType(in_arg_name);
} else {
op->op_info()->GetOutputArgname(in_name, &in_arg_name);
type = kernel->GetOutputDeclType(in_arg_name);
}
if (type->IsTensor()) {
auto tensor = scope->FindVar(var_name)->GetMutable<Tensor>();
auto tensor = scope->FindVar(in_name)->GetMutable<Tensor>();
v->SetPersistable(tensor->persistable());
if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") {
if (in_name != "feed" && in_name != "fetch") {
v->SetShape(tensor->dims().data());
switch (tensor->precision()) {
#define SET_DATATYPE(precision__, data_type) \
case PrecisionType::precision__: \
v->SetDataType(data_type); \
#define SET_DATATYPE(precision__, data_type) \
case PrecisionType::precision__: \
v->SetDataType(data_type); \
LOG(INFO) << "update var" << (it->second).Name() << "done"; \
break
SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL);
SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16);
SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
#undef SET_DATATYPE
default:
LOG(FATAL) << "unknown precision type";
VLOG(4) << "warning! unknown precision type";
}
}
} else {
......@@ -141,7 +144,6 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
}
}
}
void RuntimeProgram::Run() {
#ifdef LITE_WITH_PRECISION_PROFILE
auto inst_precision_profiler = paddle::lite::profile::PrecisionProfiler();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册