未验证 提交 eba057e3 编写于 作者: C csy0225 提交者: GitHub

Convert to mixed precision support serialize params if a origin model doesn't have params. (#52994)

上级 c29dc34e
......@@ -13,7 +13,8 @@ cc_library(
cc_library(
convert_to_mixed_precision
SRCS convert_to_mixed_precision.cc
DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass
constant_folding_pass)
cc_library(
ir_params_sync_among_devices_pass
SRCS ir_params_sync_among_devices_pass.cc
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/constant_folding_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/backend.h"
......@@ -71,8 +72,13 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
void ConvertToMixedPrecisionPass::LoadModel() {
framework::Executor exe{platform::CPUPlace{}};
auto program_desc = inference::Load(&exe, &scope_, model_file_, params_file_);
// If we did not find the provided weight path,
// we assume that the model to be converted only has a model file and no
// params file, we believe this situation is reasonable. In this case, weight
// data may not be loaded.
bool load_params = !params_file_.empty();
auto program_desc =
inference::Load(&exe, &scope_, model_file_, params_file_, load_params);
main_graph_ = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc));
main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
......@@ -81,6 +87,8 @@ void ConvertToMixedPrecisionPass::LoadModel() {
void ConvertToMixedPrecisionPass::Run() {
LoadModel();
framework::ir::ConstantFoldingPass constant_folding_pass;
constant_folding_pass.Apply(main_graph_.get());
framework::ir::AutoMixedPrecisionPass pass;
pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
if (backend_ == phi::Backend::GPU) {
......@@ -117,6 +125,7 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
const auto& global_block = mixed_program_desc.Block(0);
std::vector<std::string> save_var_list;
bool has_persistable_var = false;
for (framework::VarDesc* var : global_block.AllVars()) {
if (IsPersistable(var)) {
framework::VarDesc* new_var = save_block->Var(var->Name());
......@@ -127,13 +136,35 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
new_var->SetPersistable(true);
save_var_list.push_back(new_var->Name());
has_persistable_var = true;
}
}
std::string save_params_path = path;
if (save_params_path.empty() && has_persistable_var) {
LOG(WARNING)
<< "The [SerializeParams] function did not find the provided weight "
"path, "
"so we assume that the model to be converted only has a model "
"file and no params file, "
"we believe this situation is reasonable. After constant folding, "
"a weight file will be generated, which is saved in the same "
"level file directory "
"as the model file by default and ends in pdiparams.";
save_params_path = mixed_model_file_;
std::string::size_type pos = save_params_path.rfind(".pdmodel");
if (pos != std::string::npos) {
save_params_path.replace(pos, 8, ".pdiparams");
LOG(WARNING) << " The storage path of the converted mixed-precision "
"params has been created: ["
<< save_params_path << "]";
}
}
std::sort(save_var_list.begin(), save_var_list.end());
auto* op = save_block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", save_var_list);
op->SetAttr("file_path", path);
op->SetAttr("file_path", save_params_path);
op->CheckAttrs();
framework::Executor exe(platform::CPUPlace{});
......
......@@ -94,9 +94,14 @@ def convert_to_mixed_precision(
black_list: Operators that do not convert precision.
'''
mixed_model_dirname = os.path.dirname(mixed_model_file)
mixed_params_dirname = os.path.dirname(mixed_params_file)
if not os.path.exists(mixed_model_dirname):
os.makedirs(mixed_model_dirname)
# Support mixed_params_file is empty, because some models don't have params, but convert_to_mixed_precision will call
# constant_folding_pass, it will generate a new params file to save persistable vars, which is saved in the same
# level file directory as the model file by default and ends in pdiparams.
mixed_params_dirname = (
os.path.dirname(mixed_params_file)
if len(mixed_params_file) != 0
else mixed_model_dirname
)
if not os.path.exists(mixed_params_dirname):
os.makedirs(mixed_params_dirname)
convert_to_mixed_precision_bind(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册