From b835d958e53ab7f18f77cbb9797e607f83db4447 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Wed, 12 Apr 2023 10:30:03 +0800 Subject: [PATCH] fix convert_to_mixed_precision api save model bug (#52767) * update save model * update --- .../passes/convert_to_mixed_precision.cc | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc index 2589a20eb28..963197850c9 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -102,32 +102,53 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() { framework::ProgramDesc mixed_program_desc; framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc); - auto parameters = scope_.LocalVarNames(); - std::sort(parameters.begin(), parameters.end()); - - auto SerializeParams = [&]() -> std::string { - std::ostringstream os; - phi::CPUContext ctx; - for (const auto& param : parameters) { - PADDLE_ENFORCE_NOT_NULL( - scope_.FindVar(param), - platform::errors::NotFound( - "Block should already have a '%s' variable", param)); - auto* tensor = scope_.FindVar(param)->GetMutable(); - framework::SerializeToStream(os, *tensor, ctx); + auto SerializeParams = [&](const std::string& path) { + auto IsPersistable = [](const framework::VarDesc* var) { + if (var->Persistable() && + var->GetType() != framework::proto::VarType::FEED_MINIBATCH && + var->GetType() != framework::proto::VarType::FETCH_LIST && + var->GetType() != framework::proto::VarType::RAW) { + return true; + } + return false; + }; + framework::ProgramDesc save_program; + auto* save_block = save_program.MutableBlock(0); + + const auto& global_block = mixed_program_desc.Block(0); + std::vector save_var_list; + for (framework::VarDesc* var : global_block.AllVars()) { + if (IsPersistable(var)) { + framework::VarDesc* new_var = save_block->Var(var->Name()); + new_var->SetShape(var->GetShape()); + new_var->SetDataType(var->GetDataType()); + new_var->SetType(var->GetType()); + new_var->SetLoDLevel(var->GetLoDLevel()); + new_var->SetPersistable(true); + + save_var_list.push_back(new_var->Name()); + } } - return os.str(); + std::sort(save_var_list.begin(), save_var_list.end()); + auto* op = save_block->AppendOp(); + op->SetType("save_combine"); + op->SetInput("X", save_var_list); + op->SetAttr("file_path", path); + op->CheckAttrs(); + + framework::Executor exe(platform::CPUPlace{}); + exe.Run(save_program, &scope_, 0, true, true); }; - auto StrToBinary = [](const std::string& path, const std::string& str) { + auto SerializeProg = [&](const std::string& path) { + auto str = mixed_program_desc.Proto()->SerializeAsString(); std::ofstream file(path.c_str(), std::ios::binary); file.write(str.c_str(), str.size()); file.close(); }; - StrToBinary(mixed_model_file_, - mixed_program_desc.Proto()->SerializeAsString()); - StrToBinary(mixed_params_file_, SerializeParams()); + SerializeProg(mixed_model_file_); + SerializeParams(mixed_params_file_); } bool OpSupportPrecision(const std::string& op_type, -- GitLab