未验证 提交 b835d958 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix convert_to_mixed_precision api save model bug (#52767)

* update save model

* update
上级 9a7c83bd
......@@ -102,32 +102,53 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
auto parameters = scope_.LocalVarNames();
std::sort(parameters.begin(), parameters.end());
auto SerializeParams = [&]() -> std::string {
std::ostringstream os;
phi::CPUContext ctx;
for (const auto& param : parameters) {
PADDLE_ENFORCE_NOT_NULL(
scope_.FindVar(param),
platform::errors::NotFound(
"Block should already have a '%s' variable", param));
auto* tensor = scope_.FindVar(param)->GetMutable<phi::DenseTensor>();
framework::SerializeToStream(os, *tensor, ctx);
auto SerializeParams = [&](const std::string& path) {
auto IsPersistable = [](const framework::VarDesc* var) {
if (var->Persistable() &&
var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
var->GetType() != framework::proto::VarType::FETCH_LIST &&
var->GetType() != framework::proto::VarType::RAW) {
return true;
}
return false;
};
framework::ProgramDesc save_program;
auto* save_block = save_program.MutableBlock(0);
const auto& global_block = mixed_program_desc.Block(0);
std::vector<std::string> save_var_list;
for (framework::VarDesc* var : global_block.AllVars()) {
if (IsPersistable(var)) {
framework::VarDesc* new_var = save_block->Var(var->Name());
new_var->SetShape(var->GetShape());
new_var->SetDataType(var->GetDataType());
new_var->SetType(var->GetType());
new_var->SetLoDLevel(var->GetLoDLevel());
new_var->SetPersistable(true);
save_var_list.push_back(new_var->Name());
}
}
return os.str();
std::sort(save_var_list.begin(), save_var_list.end());
auto* op = save_block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", save_var_list);
op->SetAttr("file_path", path);
op->CheckAttrs();
framework::Executor exe(platform::CPUPlace{});
exe.Run(save_program, &scope_, 0, true, true);
};
auto StrToBinary = [](const std::string& path, const std::string& str) {
auto SerializeProg = [&](const std::string& path) {
auto str = mixed_program_desc.Proto()->SerializeAsString();
std::ofstream file(path.c_str(), std::ios::binary);
file.write(str.c_str(), str.size());
file.close();
};
StrToBinary(mixed_model_file_,
mixed_program_desc.Proto()->SerializeAsString());
StrToBinary(mixed_params_file_, SerializeParams());
SerializeProg(mixed_model_file_);
SerializeParams(mixed_params_file_);
}
bool OpSupportPrecision(const std::string& op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册