未验证 提交 b835d958 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix convert_to_mixed_precision api save model bug (#52767)

* update save model

* update
上级 9a7c83bd
...@@ -102,32 +102,53 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() { ...@@ -102,32 +102,53 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
framework::ProgramDesc mixed_program_desc; framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc); framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
auto parameters = scope_.LocalVarNames(); auto SerializeParams = [&](const std::string& path) {
std::sort(parameters.begin(), parameters.end()); auto IsPersistable = [](const framework::VarDesc* var) {
if (var->Persistable() &&
auto SerializeParams = [&]() -> std::string { var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
std::ostringstream os; var->GetType() != framework::proto::VarType::FETCH_LIST &&
phi::CPUContext ctx; var->GetType() != framework::proto::VarType::RAW) {
for (const auto& param : parameters) { return true;
PADDLE_ENFORCE_NOT_NULL( }
scope_.FindVar(param), return false;
platform::errors::NotFound( };
"Block should already have a '%s' variable", param)); framework::ProgramDesc save_program;
auto* tensor = scope_.FindVar(param)->GetMutable<phi::DenseTensor>(); auto* save_block = save_program.MutableBlock(0);
framework::SerializeToStream(os, *tensor, ctx);
const auto& global_block = mixed_program_desc.Block(0);
std::vector<std::string> save_var_list;
for (framework::VarDesc* var : global_block.AllVars()) {
if (IsPersistable(var)) {
framework::VarDesc* new_var = save_block->Var(var->Name());
new_var->SetShape(var->GetShape());
new_var->SetDataType(var->GetDataType());
new_var->SetType(var->GetType());
new_var->SetLoDLevel(var->GetLoDLevel());
new_var->SetPersistable(true);
save_var_list.push_back(new_var->Name());
}
} }
return os.str(); std::sort(save_var_list.begin(), save_var_list.end());
auto* op = save_block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", save_var_list);
op->SetAttr("file_path", path);
op->CheckAttrs();
framework::Executor exe(platform::CPUPlace{});
exe.Run(save_program, &scope_, 0, true, true);
}; };
auto StrToBinary = [](const std::string& path, const std::string& str) { auto SerializeProg = [&](const std::string& path) {
auto str = mixed_program_desc.Proto()->SerializeAsString();
std::ofstream file(path.c_str(), std::ios::binary); std::ofstream file(path.c_str(), std::ios::binary);
file.write(str.c_str(), str.size()); file.write(str.c_str(), str.size());
file.close(); file.close();
}; };
StrToBinary(mixed_model_file_, SerializeProg(mixed_model_file_);
mixed_program_desc.Proto()->SerializeAsString()); SerializeParams(mixed_params_file_);
StrToBinary(mixed_params_file_, SerializeParams());
} }
bool OpSupportPrecision(const std::string& op_type, bool OpSupportPrecision(const std::string& op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册