提交 fd07242c 编写于 作者: S sangoly 提交者: Yan Chunwei

[Protobuf] add combined-param model save/load supported test=develop (#1876)

上级 32033452
......@@ -57,8 +57,11 @@ TEST(CXXApi_LightApi, optim_model) {
});
// On ARM devices, the preferred X86 target not works, but it can still
// select ARM kernels.
cxx_api.Build(
FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
cxx_api.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kX86), PRECISION(kFloat)},
valid_places);
cxx_api.SaveModel(FLAGS_optimized_model);
}
......@@ -75,8 +78,11 @@ TEST(CXXApi_LightApi, save_and_load_model) {
});
// On ARM devices, the preferred X86 target not works, but it can still
// select ARM kernels.
cxx_api.Build(
FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
cxx_api.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kX86), PRECISION(kFloat)},
valid_places);
auto* x = cxx_api.GetInput(0);
SetConstInput(x);
......
......@@ -33,7 +33,7 @@ void Predictor::SaveModel(const std::string &dir,
program_->SaveOpInfosToProgram(&program_desc_);
switch (model_type) {
case lite_api::LiteModelType::kProtobuf:
SaveModelPb(dir, *program_->exec_scope(), program_desc_);
SaveModelPb(dir, *program_->exec_scope(), program_desc_, true);
break;
case lite_api::LiteModelType::kNaiveBuffer:
SaveModelNaive(dir, *program_->exec_scope(), program_desc_);
......@@ -84,16 +84,29 @@ const cpp::ProgramDesc &Predictor::program_desc() const {
const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
void Predictor::Build(const std::string &model_path,
const std::string model_file,
const std::string param_file,
const Place &prefer_place,
const std::vector<Place> &valid_places,
const std::vector<std::string> &passes,
lite_api::LiteModelType model_type) {
LOG(INFO) << "Load model from " << model_path;
switch (model_type) {
case lite_api::LiteModelType::kProtobuf:
LoadModelPb(model_path, scope_.get(), &program_desc_);
break;
case lite_api::LiteModelType::kProtobuf: {
bool combined_param = false;
if (!model_file.empty() && !param_file.empty()) {
combined_param = true;
}
LoadModelPb(model_path,
model_file,
param_file,
scope_.get(),
&program_desc_,
combined_param);
} break;
case lite_api::LiteModelType::kNaiveBuffer:
CHECK(!model_path.empty())
<< "NaiveBuffer backend only supported combined param";
LoadModelNaive(model_path, scope_.get(), &program_desc_);
break;
default:
......
......@@ -41,6 +41,8 @@ class LITE_API Predictor {
// Build from a model, with places set for hardware config.
void Build(
const std::string& model_path,
const std::string model_file_path,
const std::string param_file_path,
const Place& prefer_place,
const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {},
......
......@@ -41,7 +41,7 @@ void Run(const char* model_dir, int repeat) {
});
predictor.Build(
model_dir, Place{TARGET(kARM), PRECISION(kInt8)}, valid_places);
model_dir, "", "", Place{TARGET(kARM), PRECISION(kInt8)}, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -47,7 +47,11 @@ CxxPaddleApiImpl::CxxPaddleApiImpl() {}
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
auto places = config.valid_places();
places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
raw_predictor_.Build(config.model_dir(), config.preferred_place(), places);
raw_predictor_.Build(config.model_dir(),
config.model_file(),
config.param_file(),
config.preferred_place(),
places);
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
......
......@@ -45,8 +45,11 @@ TEST(CXXApi, save_model) {
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
predictor.Build(
FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kCUDA), PRECISION(kFloat)},
valid_places);
LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
predictor.SaveModel(FLAGS_optimized_model,
......@@ -93,8 +96,11 @@ TEST(CXXApi, save_model) {
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
predictor.Build(
FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kARM), PRECISION(kFloat)},
valid_places);
LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
predictor.SaveModel(FLAGS_optimized_model);
......@@ -107,6 +113,8 @@ TEST(CXXApi, load_model_naive) {
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
predictor.Build(FLAGS_optimized_model + ".naive",
"",
"",
Place{TARGET(kARM), PRECISION(kFloat)},
valid_places,
{},
......
......@@ -31,7 +31,7 @@ void TestModel(const std::vector<Place> &valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
auto *input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -33,8 +33,11 @@ TEST(InceptionV4, test) {
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
predictor.Build(
FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kARM), PRECISION(kFloat)},
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -24,7 +24,7 @@ void LightPredictor::Build(const std::string& model_dir,
switch (model_type) {
#ifndef LITE_ON_TINY_PUBLISH
case lite_api::LiteModelType::kProtobuf:
LoadModelPb(model_dir, scope_.get(), &desc);
LoadModelPb(model_dir, "", "", scope_.get(), &desc);
break;
#endif
case lite_api::LiteModelType::kNaiveBuffer:
......
......@@ -38,6 +38,8 @@ const lite::Tensor* RunHvyModel() {
#endif
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda
valid_places);
......
......@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 300, 300})));
......
......@@ -36,7 +36,7 @@ void TestModel(const std::vector<Place>& valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(model_dir, preferred_place, valid_places);
predictor.Build(model_dir, "", "", preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 608, 608})));
......
......@@ -37,7 +37,7 @@ void TestModel(const std::vector<Place>& valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(model_dir, preferred_place, valid_places);
predictor.Build(model_dir, "", "", preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -23,7 +23,12 @@
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
DEFINE_string(model_dir, "", "path of the model");
DEFINE_string(model_dir,
"",
"path of the model. This option will be ignored if model_file "
"and param_file are exist");
DEFINE_string(model_file, "", "model file path of the combined-param model");
DEFINE_string(param_file, "", "param file path of the combined-param model");
DEFINE_string(
optimize_out_type,
"protobuf",
......@@ -39,8 +44,15 @@ namespace paddle {
namespace lite_api {
void Main() {
if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) {
LOG(WARNING)
<< "Load combined-param model. Option model_dir will be ignored";
}
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_model_file(FLAGS_model_file);
config.set_param_file(FLAGS_param_file);
std::vector<Place> valid_places;
auto target_reprs = lite::Split(FLAGS_valid_targets, " ");
......
......@@ -39,7 +39,7 @@ TEST(model, test) {
precision = PRECISION(kInt8);
}
predictor.Build(
FLAGS_model_dir, Place{TARGET(kARM), precision}, valid_places);
FLAGS_model_dir, "", "", Place{TARGET(kARM), precision}, valid_places);
int im_width = FLAGS_im_width;
int im_height = FLAGS_im_height;
auto* input_tensor = predictor.GetInput(0);
......
......@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 1, 48, 512})));
......
......@@ -99,13 +99,19 @@ class LITE_API ConfigBase {
class LITE_API CxxConfig : public ConfigBase {
Place preferred_place_;
std::vector<Place> valid_places_;
std::string model_file_;
std::string param_file_;
public:
void set_preferred_place(const Place& x) { preferred_place_ = x; }
void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
void set_model_file(const std::string& path) { model_file_ = path; }
void set_param_file(const std::string& path) { param_file_ = path; }
const Place& preferred_place() const { return preferred_place_; }
const std::vector<Place>& valid_places() const { return valid_places_; }
std::string model_file() const { return model_file_; }
std::string param_file() const { return param_file_; }
};
/// MobileConfig is the config for the light weight predictor, it will skip
......
......@@ -31,8 +31,11 @@ TEST(ResNet18, test) {
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
predictor.Build(
FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kARM), PRECISION(kFloat)},
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -33,6 +33,8 @@ TEST(ResNet50, test) {
Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNHWC)}});
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
valid_places);
......
......@@ -31,7 +31,7 @@ void TestModel(const std::vector<Place>& valid_places,
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim((std::vector<DDim::value_type>({1, 3, 224, 224}))));
......
......@@ -51,7 +51,7 @@ TEST(CXXApi, test_lite_googlenet) {
// LOG(INFO)<<"FLAGS_eval_googlenet_dir:"<<FLAGS_test_lite_googlenet_dir;
std::string model_dir = FLAGS_model_dir;
predictor.Build(
model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
model_dir, "", "", Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -55,8 +55,12 @@ TEST(InceptionV4, test_inceptionv4_lite_x86) {
"io_copy_kernel_pick_pass",
"variable_place_inference_pass",
"runtime_context_assign_pass"});
predictor.Build(
model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
predictor.Build(model_dir,
"",
"",
Place{TARGET(kX86), PRECISION(kFloat)},
valid_places,
passes);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -54,8 +54,12 @@ TEST(Mobilenet_v1, test_mobilenetv1_lite_x86) {
"io_copy_kernel_pick_pass",
"variable_place_inference_pass",
"runtime_context_assign_pass"});
predictor.Build(
model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
predictor.Build(model_dir,
"",
"",
Place{TARGET(kX86), PRECISION(kFloat)},
valid_places,
passes);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
......
......@@ -56,8 +56,12 @@ TEST(Mobilenet_v2, test_mobilenetv2_lite_x86) {
"io_copy_kernel_pick_pass",
"variable_place_inference_pass",
"runtime_context_assign_pass"});
predictor.Build(
model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
predictor.Build(model_dir,
"",
"",
Place{TARGET(kX86), PRECISION(kFloat)},
valid_places,
passes);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
......@@ -33,8 +33,11 @@ TEST(unet, test) {
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
predictor.Build(
FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kARM), PRECISION(kFloat)},
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 512, 512})));
......
......@@ -44,6 +44,8 @@ TEST(fc_fuse_pass, fuse_test) {
#endif
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda
valid_places);
......@@ -72,8 +74,11 @@ TEST(fc_fuse_pass, save_model_test) {
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
predictor.Build(
FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
predictor.Build(FLAGS_model_dir,
"",
"",
Place{TARGET(kX86), PRECISION(kFloat)},
valid_places);
LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
predictor.SaveModel(FLAGS_optimized_model);
......
......@@ -40,7 +40,7 @@ void TestModel(lite::Predictor* predictor,
const std::vector<Place>& valid_places,
const std::string& model_dir) {
predictor->Build(
model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
model_dir, "", "", Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
auto* input_tensor = predictor->GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>(
......
......@@ -32,7 +32,7 @@ namespace lite {
TEST(SubgraphTest, mobilenetv2) {
cpp::ProgramDesc program_desc;
auto scope = std::make_shared<Scope>();
LoadModelPb(FLAGS_model_dir, scope.get(), &program_desc);
LoadModelPb(FLAGS_model_dir, "", "", scope.get(), &program_desc);
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
#ifdef LITE_WITH_ARM
......
......@@ -145,7 +145,10 @@ TEST(gen_code, auto_gen) {
TEST(gen_code, optimized_program) {
lite::Scope scope;
cpp::ProgramDesc cpp_desc;
LoadModelPb(FLAGS_optimized_model, &scope, &cpp_desc);
std::string model_file = FLAGS_optimized_model + "/model";
std::string param_file = FLAGS_optimized_model + "/params";
LoadModelPb(
FLAGS_optimized_model, model_file, param_file, &scope, &cpp_desc, true);
framework::proto::ProgramDesc pb_proto_desc;
lite::pb::ProgramDesc pb_desc(&pb_proto_desc);
......
......@@ -27,7 +27,9 @@ namespace gencode {
void GenCode(const std::string& model_dir, const std::string& out_file) {
lite::Scope scope;
cpp::ProgramDesc cpp_desc;
LoadModelPb(model_dir, &scope, &cpp_desc);
std::string model_file = model_dir + "/model";
std::string param_file = model_dir + "/params";
LoadModelPb(model_dir, model_file, param_file, &scope, &cpp_desc, true);
framework::proto::ProgramDesc pb_proto_desc;
lite::pb::ProgramDesc pb_desc(&pb_proto_desc);
......
......@@ -159,15 +159,63 @@ void LoadParam(const std::string &path, Variable *out) {
LoadLoDTensor(fin, out);
}
//
bool IsPersistable(const cpp::VarDesc &var) {
if (var.Persistable() && var.GetType() != VarDescAPI::Type::FEED_MINIBATCH &&
var.GetType() != VarDescAPI::Type::FETCH_LIST &&
var.GetType() != VarDescAPI::Type::RAW) {
return true;
}
return false;
}
void LoadCombinedParamsPb(const std::string &path,
lite::Scope *scope,
const cpp::ProgramDesc &cpp_prog) {
CHECK(scope);
auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// Get vars
std::vector<std::string> paramlist;
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (!IsPersistable(var)) continue;
paramlist.push_back(var.Name());
}
std::sort(paramlist.begin(), paramlist.end());
// Load vars
std::ifstream file(path);
CHECK(file.is_open());
for (size_t i = 0; i < paramlist.size(); ++i) {
auto *var = scope->Var(paramlist[i]);
// Error checking
CHECK(static_cast<bool>(file))
<< "There is a problem with loading model parameters";
LoadLoDTensor(file, var);
}
file.peek();
CHECK(file.eof()) << "You are not allowed to load partial data via"
<< " LoadCombinedParamsPb, use LoadParam instead.";
file.close();
}
void LoadModelPb(const std::string &model_dir,
const std::string &model_file,
const std::string &param_file,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
cpp::ProgramDesc *cpp_prog,
bool combined) {
CHECK(cpp_prog);
CHECK(scope);
cpp_prog->ClearBlocks();
// Load model
const std::string prog_path = model_dir + "/__model__";
std::string prog_path = model_dir + "/__model__";
if (combined) {
prog_path = model_file;
}
framework::proto::ProgramDesc pb_proto_prog = *LoadProgram(prog_path);
pb::ProgramDesc pb_prog(&pb_proto_prog);
......@@ -176,6 +224,9 @@ void LoadModelPb(const std::string &model_dir,
// Load Params
// NOTE: Only main block be used now.
if (combined) {
LoadCombinedParamsPb(param_file, scope, *cpp_prog);
} else {
auto main_block = pb_proto_prog.blocks(0);
for (auto &var : main_block.vars()) {
if (var.name() == "feed" || var.name() == "fetch" || !var.persistable())
......@@ -193,6 +244,8 @@ void LoadModelPb(const std::string &model_dir,
CHECK(false) << "unknown weight type";
}
}
}
#ifdef LITE_WITH_NPU
for (auto &op : main_block.ops()) {
LOG(INFO) << "op type:" << op.type();
......@@ -216,14 +269,18 @@ void LoadModelPb(const std::string &model_dir,
void SaveModelPb(const std::string &model_dir,
const Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
const cpp::ProgramDesc &cpp_prog,
bool combined) {
MkDirRecur(model_dir);
// Save program
framework::proto::ProgramDesc pb_proto_prog;
pb::ProgramDesc pb_prog(&pb_proto_prog);
TransformProgramDescCppToAny(cpp_prog, &pb_prog);
const std::string prog_path = model_dir + "/__model__";
std::string prog_path = model_dir + "/__model__";
if (combined) {
prog_path = model_dir + "/model";
}
std::ofstream model_ostream(prog_path, std::ios_base::binary);
CHECK(model_ostream.is_open());
const std::string pb_str = pb_proto_prog.SerializeAsString();
......@@ -232,8 +289,13 @@ void SaveModelPb(const std::string &model_dir,
// Save Params
// NOTE: Only main block be used now.
if (combined) {
const std::string combined_params_path = model_dir + "/params";
SaveCombinedParamsPb(combined_params_path, exec_scope, cpp_prog);
} else {
for (auto &item : pb_proto_prog.blocks(0).vars()) {
if (item.name() == "feed" || item.name() == "fetch" || !item.persistable())
if (item.name() == "feed" || item.name() == "fetch" ||
!item.persistable())
continue;
const std::string path = model_dir + "/" + item.name();
std::ofstream var_ostream(path, std::ios::binary);
......@@ -241,9 +303,34 @@ void SaveModelPb(const std::string &model_dir,
SerializeTensor(var_ostream, exec_scope, item.name());
var_ostream.close();
}
}
VLOG(4) << "Save protobuf model in '" << model_dir << "'' successfully";
}
void SaveCombinedParamsPb(const std::string &path,
const lite::Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// Get vars
std::vector<std::string> paramlist;
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (!IsPersistable(var)) continue;
paramlist.push_back(var.Name());
}
std::sort(paramlist.begin(), paramlist.end());
// Load vars
std::ofstream file(path);
CHECK(file.is_open());
for (size_t i = 0; i < paramlist.size(); ++i) {
SerializeTensor(file, exec_scope, paramlist[i]);
}
file.close();
}
void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
// the 1st field, uint32_t version
constexpr uint32_t version = 0;
......
......@@ -41,14 +41,26 @@ void LoadParams(const std::string& path);
// Load a single parameter to an output tensor.
void LoadParam(const std::string& path, Variable* out);
void LoadCombinedParamsPb(const std::string& path,
lite::Scope* scope,
const cpp::ProgramDesc& prog);
// Read a model and files of parameters in pb format.
void LoadModelPb(const std::string& model_dir,
const std::string& model_file,
const std::string& param_file,
Scope* scope,
cpp::ProgramDesc* prog);
cpp::ProgramDesc* prog,
bool combined = false);
// Save a model and files of parameters in pb format.
void SaveModelPb(const std::string& model_dir,
const Scope& scope,
const cpp::ProgramDesc& prog,
bool combined = false);
void SaveCombinedParamsPb(const std::string& path,
const lite::Scope& exec_scope,
const cpp::ProgramDesc& prog);
// Serialize tensors to ostream.
......
......@@ -40,18 +40,38 @@ TEST(ModelParser, LoadModelPb) {
CHECK(!FLAGS_model_dir.empty());
cpp::ProgramDesc prog;
Scope scope;
LoadModelPb(FLAGS_model_dir, &scope, &prog);
LoadModelPb(FLAGS_model_dir, "", "", &scope, &prog);
}
TEST(ModelParser, SaveModelPb) {
CHECK(!FLAGS_model_dir.empty());
cpp::ProgramDesc prog;
Scope scope;
LoadModelPb(FLAGS_model_dir, &scope, &prog);
LoadModelPb(FLAGS_model_dir, "", "", &scope, &prog);
const std::string save_pb_model_path = FLAGS_model_dir + ".saved.pb";
SaveModelPb(save_pb_model_path, scope, prog);
}
TEST(ModelParser, SaveModelCombinedPb) {
CHECK(!FLAGS_model_dir.empty());
cpp::ProgramDesc prog;
Scope scope;
LoadModelPb(FLAGS_model_dir, "", "", &scope, &prog);
const std::string save_pb_model_path = FLAGS_model_dir + ".saved.pb.combined";
SaveModelPb(save_pb_model_path, scope, prog, true);
}
TEST(ModelParser, LoadModelCombinedPb) {
CHECK(!FLAGS_model_dir.empty());
const std::string model_path = FLAGS_model_dir + ".saved.pb.combined";
cpp::ProgramDesc prog;
Scope scope;
std::string model_file_path = FLAGS_model_dir + ".saved.pb.combined/model";
std::string param_file_path = FLAGS_model_dir + ".saved.pb.combined/params";
LoadModelPb(
model_path, model_file_path, param_file_path, &scope, &prog, true);
}
TEST(ModelParser, SaveParamNaive) {
Scope scope;
auto* tensor = scope.Var("xxx")->GetMutable<lite::Tensor>();
......@@ -94,7 +114,7 @@ TEST(ModelParser, SaveModelNaive) {
CHECK(!FLAGS_model_dir.empty());
cpp::ProgramDesc prog;
Scope scope;
LoadModelPb(FLAGS_model_dir, &scope, &prog);
LoadModelPb(FLAGS_model_dir, "", "", &scope, &prog);
const std::string save_pb_model_path = FLAGS_model_dir + ".saved.naive";
SaveModelNaive(save_pb_model_path, scope, prog);
}
......
......@@ -61,6 +61,8 @@ void Run(DebugConfig* conf) {
}};
predictor.Build(conf->model_dir,
"",
"",
#ifdef LITE_WITH_ARM
Place{TARGET(kARM), PRECISION(kFloat)},
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册