提交 25dfe6b1 编写于 作者: H huzhiqiang 提交者: GitHub

【MODEL UPDATE】: combine params and model (#2800)

上级 b9c3a45d
......@@ -296,10 +296,10 @@ if (LITE_ON_TINY_PUBLISH)
endif()
if (LITE_ON_MODEL_OPTIMIZE_TOOL)
message(STATUS "Compiling OPT")
lite_cc_binary(OPT SRCS opt.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc
message(STATUS "Compiling opt")
lite_cc_binary(opt SRCS opt.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc
DEPS gflags kernel op optimizer mir_passes utils)
add_dependencies(OPT op_list_h kernel_list_h all_kernel_faked_cc supported_kernel_op_info_h)
add_dependencies(opt op_list_h kernel_list_h all_kernel_faked_cc supported_kernel_op_info_h)
endif(LITE_ON_MODEL_OPTIMIZE_TOOL)
lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light
......
......@@ -181,6 +181,7 @@ inline MobileConfig jmobileconfig_to_cpp_mobileconfig(JNIEnv *env,
MobileConfig config;
// set model dir
// NOTE: This is a deprecated API and will be removed in latter release.
jmethodID model_dir_method = env->GetMethodID(
mobileconfig_jclazz, "getModelDir", "()Ljava/lang/String;");
jstring java_model_dir =
......@@ -190,6 +191,27 @@ inline MobileConfig jmobileconfig_to_cpp_mobileconfig(JNIEnv *env,
config.set_model_dir(cpp_model_dir);
}
// set model from file
jmethodID model_file_method = env->GetMethodID(
mobileconfig_jclazz, "getModelFromFile", "()Ljava/lang/String;");
jstring java_model_file =
(jstring)env->CallObjectMethod(jmobileconfig, model_file_method);
if (java_model_file != nullptr) {
std::string cpp_model_file = jstring_to_cpp_string(env, java_model_file);
config.set_model_from_file(cpp_model_file);
}
// set model from buffer
jmethodID model_buffer_method = env->GetMethodID(
mobileconfig_jclazz, "getModelFromBuffer", "()Ljava/lang/String;");
jstring java_model_buffer =
(jstring)env->CallObjectMethod(jmobileconfig, model_buffer_method);
if (java_model_buffer != nullptr) {
std::string cpp_model_buffer =
jstring_to_cpp_string(env, java_model_buffer);
config.set_model_from_buffer(cpp_model_buffer);
}
// set threads
jmethodID threads_method =
env->GetMethodID(mobileconfig_jclazz, "getThreads", "()I");
......
......@@ -64,6 +64,44 @@ public class MobileConfig extends ConfigBase {
return powerMode.value();
}
/**
* Set model from file.
*
* @return
*/
public void setModelFromFile(String modelFile) {
this.liteModelFile = modelFile;
}
/**
* Returns name of model_file.
*
* @return liteModelFile
*/
public String getModelFile() {
return liteModelFile;
}
/**
* Set model from buffer.
*
* @return
*/
public void setModelFromBuffer(String modelBuffer) {
this.liteModelBuffer = modelBuffer;
}
/**
* Returns model buffer
*
* @return liteModelBuffer
*/
public String getModelBuffer() {
return liteModelBuffer;
}
private PowerMode powerMode = PowerMode.LITE_POWER_HIGH;
private int threads = 1;
private String liteModelFile;
private String liteModelBuffer;
}
......@@ -62,7 +62,7 @@ TEST(CXXApi_LightApi, optim_model) {
TEST(CXXApi_LightApi, save_and_load_model) {
lite::Predictor cxx_api;
lite::LightPredictor light_api(FLAGS_optimized_model);
lite::LightPredictor light_api(FLAGS_optimized_model + ".nb", false);
// CXXAPi
{
......
......@@ -116,7 +116,7 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
lite_api::MobileConfig config;
config.set_threads(FLAGS_threads);
config.set_power_mode(static_cast<PowerMode>(FLAGS_power_mode));
config.set_model_dir(model_dir);
config.set_model_from_file(model_dir + ".nb");
auto predictor = lite_api::CreatePaddlePredictor(config);
......
......@@ -239,7 +239,7 @@ void Predictor::Build(const std::string &model_path,
case lite_api::LiteModelType::kNaiveBuffer:
CHECK(!model_path.empty())
<< "NaiveBuffer backend only supported combined param";
LoadModelNaive(model_path, scope_.get(), &program_desc_);
LoadModelNaiveFromFile(model_path, scope_.get(), &program_desc_);
break;
default:
LOG(FATAL) << "Unknown model type";
......
......@@ -101,7 +101,7 @@ TEST(CXXApi, save_model) {
TEST(CXXApi, load_model_naive) {
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});
predictor.Build(FLAGS_optimized_model + ".naive",
predictor.Build(FLAGS_optimized_model + ".naive.nb",
"",
"",
valid_places,
......
......@@ -18,6 +18,17 @@
namespace paddle {
namespace lite {
void LightPredictor::Build(const std::string& lite_model_file,
bool model_from_memory) {
if (model_from_memory) {
LoadModelNaiveFromMemory(lite_model_file, scope_.get(), &cpp_program_desc_);
} else {
LoadModelNaiveFromFile(lite_model_file, scope_.get(), &cpp_program_desc_);
}
BuildRuntimeProgram(cpp_program_desc_);
PrepareFeedFetch();
}
void LightPredictor::Build(const std::string& model_dir,
const std::string& model_buffer,
const std::string& param_buffer,
......
......@@ -18,6 +18,7 @@
*/
#pragma once
#include <algorithm>
#include <map>
#include <memory>
#include <string>
......@@ -39,12 +40,22 @@ namespace lite {
*/
class LITE_API LightPredictor {
public:
LightPredictor(
const std::string& model_dir,
const std::string& model_buffer = "",
const std::string& param_buffer = "",
bool model_from_memory = false,
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf) {
// constructor function of LightPredictor, `lite_model_file` refers to data in
// model file or buffer,`model_from_memory` refers to whther to load model
// from memory.
LightPredictor(const std::string& lite_model_file,
bool model_from_memory = false) {
scope_ = std::make_shared<Scope>();
Build(lite_model_file, model_from_memory);
}
// NOTE: This is a deprecated API and will be removed in latter release.
LightPredictor(const std::string& model_dir,
const std::string& model_buffer = "",
const std::string& param_buffer = "",
bool model_from_memory = false,
lite_api::LiteModelType model_type =
lite_api::LiteModelType::kNaiveBuffer) {
scope_ = std::make_shared<Scope>();
Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory);
}
......@@ -69,6 +80,10 @@ class LITE_API LightPredictor {
void PrepareFeedFetch();
private:
void Build(const std::string& lite_model_file,
bool model_from_memory = false);
// NOTE: This is a deprecated API and will be removed in latter release.
void Build(
const std::string& model_dir,
const std::string& model_buffer,
......
......@@ -23,13 +23,17 @@ namespace lite {
void LightPredictorImpl::Init(const lite_api::MobileConfig& config) {
// LightPredictor Only support NaiveBuffer backend in publish lib
raw_predictor_.reset(
new LightPredictor(config.model_dir(),
config.model_buffer(),
config.param_buffer(),
config.model_from_memory(),
lite_api::LiteModelType::kNaiveBuffer));
if (config.lite_model_file().empty()) {
raw_predictor_.reset(
new LightPredictor(config.model_dir(),
config.model_buffer(),
config.param_buffer(),
config.model_from_memory(),
lite_api::LiteModelType::kNaiveBuffer));
} else {
raw_predictor_.reset(new LightPredictor(config.lite_model_file(),
config.model_from_memory()));
}
mode_ = config.power_mode();
threads_ = config.threads();
}
......
......@@ -73,7 +73,7 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
const int repeat,
const int warmup_times = 0) {
lite_api::MobileConfig config;
config.set_model_dir(model_dir);
config.set_model_from_file(model_dir + ".nb");
config.set_power_mode(power_mode);
config.set_threads(thread_num);
......
......@@ -17,7 +17,7 @@
#include <gtest/gtest.h>
#endif
// "supported_kernel_op_info.h", "all_kernel_faked.cc" and "kernel_src_map.h"
// are created automatically during OPT's compiling period
// are created automatically during opt's compiling period
#include <iomanip>
#include "all_kernel_faked.cc" // NOLINT
#include "kernel_src_map.h" // NOLINT
......
......@@ -190,5 +190,27 @@ void ConfigBase::set_threads(int threads) {
#endif
}
// set model data in combined format, `set_model_from_file` refers to loading
// model from file, set_model_from_buffer refers to loading model from memory
// buffer
void MobileConfig::set_model_from_file(const std::string &x) {
lite_model_file_ = x;
}
void MobileConfig::set_model_from_buffer(const std::string &x) {
lite_model_file_ = x;
model_from_memory_ = true;
}
void MobileConfig::set_model_buffer(const char *model_buffer,
size_t model_buffer_size,
const char *param_buffer,
size_t param_buffer_size) {
LOG(WARNING) << "warning: `set_model_buffer` will be abandened in "
"release/v3.0.0, new method `set_model_from_buffer(const "
"std::string &x)` is recommended.";
model_buffer_ = std::string(model_buffer, model_buffer + model_buffer_size);
param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size);
model_from_memory_ = true;
}
} // namespace lite_api
} // namespace paddle
......@@ -168,22 +168,40 @@ class LITE_API CxxConfig : public ConfigBase {
/// MobileConfig is the config for the light weight predictor, it will skip
/// IR optimization or other unnecessary stages.
class LITE_API MobileConfig : public ConfigBase {
// whether to load data from memory. Model data will be loaded from memory
// buffer if model_from_memory_ is true.
bool model_from_memory_{false};
// model data readed from file or memory buffer in combined format.
std::string lite_model_file_;
// NOTE: This is a deprecated variable and will be removed in latter release.
std::string model_buffer_;
std::string param_buffer_;
bool model_from_memory_{false};
public:
// set model data in combined format, `set_model_from_file` refers to loading
// model from file, set_model_from_buffer refers to loading model from memory
// buffer
void set_model_from_file(const std::string& x);
void set_model_from_buffer(const std::string& x);
// return model data in lite_model_file_, which is in combined format.
const std::string& lite_model_file() const { return lite_model_file_; }
// return model_from_memory_, which indicates whether to load model from
// memory buffer.
bool model_from_memory() const { return model_from_memory_; }
// NOTE: This is a deprecated API and will be removed in latter release.
void set_model_buffer(const char* model_buffer,
size_t model_buffer_size,
const char* param_buffer,
size_t param_buffer_size) {
model_buffer_ = std::string(model_buffer, model_buffer + model_buffer_size);
param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size);
model_from_memory_ = true;
}
size_t param_buffer_size);
bool model_from_memory() const { return model_from_memory_; }
// NOTE: This is a deprecated API and will be removed in latter release.
const std::string& model_buffer() const { return model_buffer_; }
// NOTE: This is a deprecated API and will be removed in latter release.
const std::string& param_buffer() const { return param_buffer_; }
};
......
......@@ -72,7 +72,7 @@ TEST(CxxApi, run) {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(LightApi, run) {
lite_api::MobileConfig config;
config.set_model_dir(FLAGS_model_dir + ".opt2.naive");
config.set_model_from_file(FLAGS_model_dir + ".opt2.naive.nb");
auto predictor = lite_api::CreatePaddlePredictor(config);
......@@ -109,16 +109,11 @@ TEST(LightApi, run) {
// Demo2 for Loading model from memory
TEST(MobileConfig, LoadfromMemory) {
// Get naive buffer
auto model_path = std::string(FLAGS_model_dir) + ".opt2.naive/__model__.nb";
auto params_path = std::string(FLAGS_model_dir) + ".opt2.naive/param.nb";
std::string model_buffer = lite::ReadFile(model_path);
size_t size_model = model_buffer.length();
std::string params_buffer = lite::ReadFile(params_path);
size_t size_params = params_buffer.length();
auto model_file = std::string(FLAGS_model_dir) + ".opt2.naive.nb";
std::string model_buffer = lite::ReadFile(model_file);
// set model buffer and run model
lite_api::MobileConfig config;
config.set_model_buffer(
model_buffer.c_str(), size_model, params_buffer.c_str(), size_params);
config.set_model_from_buffer(model_buffer);
auto predictor = lite_api::CreatePaddlePredictor(config);
auto input_tensor = predictor->GetInput(0);
......
......@@ -116,6 +116,8 @@ void BindLiteMobileConfig(py::module *m) {
py::class_<MobileConfig> mobile_config(*m, "MobileConfig");
mobile_config.def(py::init<>())
.def("set_model_from_file", &MobileConfig::set_model_from_file)
.def("set_model_from_buffer", &MobileConfig::set_model_from_buffer)
.def("set_model_dir", &MobileConfig::set_model_dir)
.def("model_dir", &MobileConfig::model_dir)
.def("set_model_buffer", &MobileConfig::set_model_buffer)
......
......@@ -157,7 +157,7 @@ std::shared_ptr<lite_api::PaddlePredictor> TestModel(
lite_api::LiteModelType::kNaiveBuffer);
// Load optimized model
lite_api::MobileConfig mobile_config;
mobile_config.set_model_dir(optimized_model_dir);
mobile_config.set_model_from_file(optimized_model_dir + ".nb");
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
......
......@@ -42,7 +42,7 @@ static std::string version() {
std::string tag = paddlelite_tag();
if (tag.empty()) {
ss << paddlelite_branch() << "(" << paddlelite_commit() << ")";
ss << paddlelite_commit();
} else {
ss << tag;
}
......
......@@ -20,6 +20,7 @@
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/core/variable.h"
#include "lite/core/version.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/naive_buffer/combined_params_desc.h"
#include "lite/model_parser/naive_buffer/param_desc.h"
......@@ -536,7 +537,7 @@ void SaveCombinedParamsNaive(const std::string &path,
}
pt_desc.Save();
table.SaveToFile(path);
table.AppendToFile(path);
}
void SaveModelNaive(const std::string &model_dir,
......@@ -545,30 +546,46 @@ void SaveModelNaive(const std::string &model_dir,
bool combined) {
MkDirRecur(model_dir);
// Save program
const std::string prog_path = model_dir + "/__model__.nb";
const std::string prog_path = model_dir + ".nb";
naive_buffer::BinaryTable table;
naive_buffer::proto::ProgramDesc nb_proto_prog(&table);
naive_buffer::ProgramDesc nb_prog(&nb_proto_prog);
TransformProgramDescCppToAny(cpp_prog, &nb_prog);
nb_proto_prog.Save();
table.SaveToFile(prog_path);
// Save meta_version(uint16) into file
naive_buffer::BinaryTable meta_version_table;
meta_version_table.Require(sizeof(uint16_t));
uint16_t meta_version = 0;
memcpy(meta_version_table.cursor(), &meta_version, sizeof(uint16_t));
meta_version_table.Consume(sizeof(uint16_t));
meta_version_table.SaveToFile(prog_path);
// Save lite_version(char[16]) into file
const int paddle_version_length = 16 * sizeof(char);
naive_buffer::BinaryTable paddle_version_table;
paddle_version_table.Require(paddle_version_length);
std::string paddle_version = version();
memcpy(paddle_version_table.cursor(),
paddle_version.c_str(),
paddle_version_length);
paddle_version_table.Consume(paddle_version_length);
paddle_version_table.AppendToFile(prog_path);
VLOG(4) << "paddle_version:" << paddle_version << std::endl;
// Save topology_size(uint64) into file
naive_buffer::BinaryTable topology_size_table;
topology_size_table.Require(sizeof(uint64_t));
uint64_t topology_size = table.size();
memcpy(topology_size_table.cursor(), &topology_size, sizeof(uint64_t));
topology_size_table.Consume(sizeof(uint64_t));
topology_size_table.AppendToFile(prog_path);
// save topology data into model file
table.AppendToFile(prog_path);
// Save Params
// NOTE: Only main block be used now.
if (combined) {
const std::string combined_params_path = model_dir + "/param.nb";
SaveCombinedParamsNaive(combined_params_path, exec_scope, cpp_prog);
} else {
auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable())
continue;
const std::string path = model_dir + "/" + var.Name() + ".nb";
SaveParamNaive(path, exec_scope, var.Name());
}
}
SaveCombinedParamsNaive(prog_path, exec_scope, cpp_prog);
LOG(INFO) << "Save naive buffer model in '" << model_dir << "' successfully";
}
#endif
......@@ -638,14 +655,15 @@ void LoadParamNaive(const std::string &path,
}
void LoadCombinedParamsNaive(const std::string &path,
const uint64_t &offset,
lite::Scope *scope,
const cpp::ProgramDesc &cpp_prog,
bool params_from_memory) {
naive_buffer::BinaryTable table;
if (params_from_memory) {
table.LoadFromMemory(path.c_str(), path.length());
table.LoadFromMemory(path.c_str() + offset, path.length() - offset);
} else {
table.LoadFromFile(path);
table.LoadFromFile(path, offset, 0);
}
naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
pt_desc.Load();
......@@ -693,7 +711,7 @@ void LoadModelNaive(const std::string &model_dir,
// NOTE: Only main block be used now.
if (combined) {
const std::string combined_params_path = model_dir + "/param.nb";
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, false);
LoadCombinedParamsNaive(combined_params_path, 0, scope, *cpp_prog, false);
} else {
auto &prog = *cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
......@@ -718,6 +736,84 @@ void LoadModelNaive(const std::string &model_dir,
VLOG(4) << "Load naive buffer model in '" << model_dir << "' successfully";
}
/*
* Binary structure of naive_buffer model: model.nb
* ----------------------------------------------------------
* | | PART | Precision | Length(byte) |
* | 1 | meta_version | uint16_t | 2 |
* | 2 | opt_version | char[16] | 16 |
* | 3 | topo_size | uint64_t | 8 |
* | 4 | topo_data | char[] | topo_size byte |
* | 5 | param_data | char[] | |
* ----------------------------------------------------------
* Meaning of each part:
* meta_version: meata_version, 0 default.
* opt_version: lite_version of opt tool that transformed this model.
* topo_size: length of `topo_data`.
* topo_data: contains model's topology data.
* param_data: contains model's params data.
*/
// usage: LoadModelNaiveFromFile is used for loading model from file.
template <typename T>
void ReadModelDataFromFile(T *data,
const std::string &prog_path,
uint64_t *offset,
const uint64_t &size) {
naive_buffer::BinaryTable data_table;
data_table.LoadFromFile(prog_path, *offset, size);
memcpy(data, data_table.cursor(), size);
*offset = *offset + size;
}
void LoadModelNaiveFromFile(const std::string &filename,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
CHECK(cpp_prog);
CHECK(scope);
cpp_prog->ClearBlocks();
// ModelFile
const std::string prog_path = filename;
// Offset
uint64_t offset = 0;
// (1)get meta version
uint16_t meta_version;
ReadModelDataFromFile<uint16_t>(
&meta_version, prog_path, &offset, sizeof(uint16_t));
VLOG(4) << "Meta_version:" << meta_version;
// (2)get opt version
char opt_version[16];
const uint64_t paddle_version_length = 16 * sizeof(char);
ReadModelDataFromFile<char>(
opt_version, prog_path, &offset, paddle_version_length);
VLOG(4) << "Opt_version:" << opt_version;
// (3)get topo_size
uint64_t topo_size;
ReadModelDataFromFile<uint64_t>(
&topo_size, prog_path, &offset, sizeof(uint64_t));
// (4)get topo data
naive_buffer::BinaryTable topo_table;
topo_table.LoadFromFile(prog_path, offset, topo_size);
offset = offset + topo_size;
// transform topo_data into cpp::ProgramDesc
naive_buffer::proto::ProgramDesc nb_proto_prog(&topo_table);
nb_proto_prog.Load();
naive_buffer::ProgramDesc nb_prog(&nb_proto_prog);
TransformProgramDescAnyToCpp(nb_prog, cpp_prog);
// (5)Load Params
LoadCombinedParamsNaive(prog_path, offset, scope, *cpp_prog, false);
VLOG(4) << "Load naive buffer model in '" << filename << "' successfully";
}
// warning: this is an old inference and is not suggested.
// todo: this inference will be abandened in release/v3.0.0
void LoadModelNaiveFromMemory(const std::string &model_buffer,
const std::string &param_buffer,
Scope *scope,
......@@ -741,7 +837,64 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,
// Load Params
// NOTE: Only main block be used now.
// only combined Params are supported in Loading Model from memory
LoadCombinedParamsNaive(param_buffer, scope, *cpp_prog, true);
LoadCombinedParamsNaive(param_buffer, 0, scope, *cpp_prog, true);
VLOG(4) << "Load model from naive buffer memory successfully";
}
// usage: LoadModelNaiveFromMemory is used for loading naive model from memory
template <typename T>
void ReadModelDataFromBuffer(T *data,
const std::string &model_buffer,
uint64_t *offset,
const uint64_t &size) {
naive_buffer::BinaryTable data_table;
data_table.LoadFromMemory(model_buffer.c_str() + *offset, size);
memcpy(data, data_table.cursor(), size);
*offset = *offset + size;
}
void LoadModelNaiveFromMemory(const std::string &model_buffer,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
CHECK(cpp_prog);
CHECK(scope);
cpp_prog->ClearBlocks();
// Offset
uint64_t offset = 0;
// (1)get meta version
uint16_t meta_version;
ReadModelDataFromBuffer<uint16_t>(
&meta_version, model_buffer, &offset, sizeof(uint16_t));
VLOG(4) << "Meta_version:" << meta_version;
// (2)get opt version
char opt_version[16];
const uint64_t paddle_version_length = 16 * sizeof(char);
ReadModelDataFromBuffer<char>(
opt_version, model_buffer, &offset, paddle_version_length);
VLOG(4) << "Opt_version:" << opt_version;
// (3)get topo_size and topo_data
uint64_t topo_size;
ReadModelDataFromBuffer<uint64_t>(
&topo_size, model_buffer, &offset, sizeof(uint64_t));
naive_buffer::BinaryTable table;
table.LoadFromMemory(model_buffer.c_str() + offset, topo_size);
offset = offset + topo_size;
naive_buffer::proto::ProgramDesc nb_proto_prog(&table);
nb_proto_prog.Load();
naive_buffer::ProgramDesc nb_prog(&nb_proto_prog);
// Transform to cpp::ProgramDesc
TransformProgramDescAnyToCpp(nb_prog, cpp_prog);
// Load Params
// NOTE: Only main block be used now.
// only combined Params are supported in Loading Model from memory
LoadCombinedParamsNaive(model_buffer, offset, scope, *cpp_prog, true);
VLOG(4) << "Load model from naive buffer memory successfully";
}
......
......@@ -94,15 +94,22 @@ void LoadParamNaive(const std::string& path,
lite::Scope* scope,
const std::string& name);
// warning:this old inference will be abandened in release/v3.0.0
// and LoadModelNaiveFromFile is suggested.
void LoadModelNaive(const std::string& model_dir,
lite::Scope* scope,
cpp::ProgramDesc* prog,
bool combined = true);
void LoadModelNaiveFromFile(const std::string& filename,
lite::Scope* scope,
cpp::ProgramDesc* prog);
void LoadModelNaiveFromMemory(const std::string& model_buffer,
const std::string& param_buffer,
lite::Scope* scope,
cpp::ProgramDesc* cpp_prog);
void LoadModelNaiveFromMemory(const std::string& model_buffer,
lite::Scope* scope,
cpp::ProgramDesc* cpp_prog);
} // namespace lite
} // namespace paddle
......@@ -121,17 +121,23 @@ TEST(ModelParser, SaveModelNaive) {
SaveModelNaive(save_pb_model_path, scope, prog);
}
TEST(ModelParser, LoadModelNaiveFromFile) {
CHECK(!FLAGS_model_dir.empty());
cpp::ProgramDesc prog;
Scope scope;
auto model_path = std::string(FLAGS_model_dir) + ".saved.naive.nb";
LoadModelNaiveFromFile(model_path, &scope, &prog);
}
TEST(ModelParser, LoadModelNaiveFromMemory) {
CHECK(!FLAGS_model_dir.empty());
cpp::ProgramDesc prog;
Scope scope;
auto model_path = std::string(FLAGS_model_dir) + ".saved.naive/__model__.nb";
auto params_path = std::string(FLAGS_model_dir) + ".saved.naive/param.nb";
auto model_path = std::string(FLAGS_model_dir) + ".saved.naive.nb";
std::string model_buffer = lite::ReadFile(model_path);
std::string params_buffer = lite::ReadFile(params_path);
LoadModelNaiveFromMemory(model_buffer, params_buffer, &scope, &prog);
LoadModelNaiveFromMemory(model_buffer, &scope, &prog);
}
} // namespace lite
......
......@@ -44,24 +44,37 @@ void BinaryTable::SaveToFile(const std::string &filename) const {
fclose(fp);
}
void BinaryTable::LoadFromFile(const std::string &filename) {
// get file size
void BinaryTable::AppendToFile(const std::string &filename) const {
FILE *fp = fopen(filename.c_str(), "ab");
CHECK(fp) << "Unable to open file: " << filename;
if (fwrite(reinterpret_cast<const char *>(data()), 1, size(), fp) != size()) {
fclose(fp);
LOG(FATAL) << "Write file error: " << filename;
}
fclose(fp);
}
void BinaryTable::LoadFromFile(const std::string &filename,
const size_t &offset,
const size_t &size) {
// open file in readonly mode
FILE *fp = fopen(filename.c_str(), "rb");
CHECK(fp) << "Unable to open file: " << filename;
fseek(fp, 0L, SEEK_END);
size_t file_size = ftell(fp);
LOG(INFO) << "file size " << file_size;
// load data.
fseek(fp, 0L, SEEK_SET);
Require(file_size);
if (fread(reinterpret_cast<char *>(&bytes_[0]), 1, file_size, fp) !=
file_size) {
// move fstream pointer backward for size of offset
size_t buffer_size = size;
if (size == 0) {
fseek(fp, 0L, SEEK_END);
buffer_size = ftell(fp) - offset;
}
fseek(fp, offset, SEEK_SET);
Require(buffer_size);
// read data of `size` into binary_data_variable:`bytes_`
if (fread(reinterpret_cast<char *>(&bytes_[0]), 1, buffer_size, fp) !=
buffer_size) {
fclose(fp);
LOG(FATAL) << "Read file error: " << filename;
}
fclose(fp);
// Set readonly.
is_mutable_mode_ = false;
}
......
......@@ -61,8 +61,12 @@ struct BinaryTable {
/// Serialize the table to a binary buffer.
void SaveToFile(const std::string& filename) const;
void AppendToFile(const std::string& filename) const;
void LoadFromFile(const std::string& filename);
// void LoadFromFile(const std::string& filename);
void LoadFromFile(const std::string& filename,
const size_t& offset = 0,
const size_t& size = 0);
void LoadFromMemory(const char* buffer, size_t buffer_size);
};
......
......@@ -14,7 +14,7 @@ readonly NUM_PROC=${LITE_BUILD_THREADS:-4}
# global variables
BUILD_EXTRA=OFF
BUILD_JAVA=ON
BUILD_JAVA=OFF
BUILD_PYTHON=OFF
BUILD_DIR=$(pwd)
OPTMODEL_DIR=""
......@@ -72,7 +72,7 @@ function build_opt {
-DWITH_TESTING=OFF \
-DLITE_BUILD_EXTRA=ON \
-DWITH_MKL=OFF
make OPT -j$NUM_PROC
make opt -j$NUM_PROC
}
function make_tiny_publish_so {
......
......@@ -519,7 +519,7 @@ function test_model_optimize_tool_compile {
cd $workspace
cd build
cmake .. -DWITH_LITE=ON -DLITE_ON_MODEL_OPTIMIZE_TOOL=ON -DWITH_TESTING=OFF -DLITE_BUILD_EXTRA=ON
make OPT -j$NUM_CORES_FOR_COMPILE
make opt -j$NUM_CORES_FOR_COMPILE
}
function _test_paddle_code_generator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册