Commit 2109d231 authored by Zhaolong Xing, committed by GitHub

Add load from memory interface (#1903)

* paddle lite cuda init
can run model with leaky_relu

* add the missing file.
test=develop

* add the load from memory interface.
test=develop

* refine this pr. fix comments
fix ci error
test=develop
Parent commit: ab84a3d5
......@@ -83,14 +83,34 @@ const cpp::ProgramDesc &Predictor::program_desc() const {
}
const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
// Build the predictor from a CxxConfig: unpack the config fields and
// forward to the path/file-based Build overload.
void Predictor::Build(const lite_api::CxxConfig &config,
                      const std::vector<Place> &valid_places,
                      const std::vector<std::string> &passes,
                      lite_api::LiteModelType model_type) {
  const std::string &model_path = config.model_dir();
  const std::string &model_file = config.model_file();
  const std::string &param_file = config.param_file();
  const Place prefer_place = config.preferred_place();
  // When true, model_file/param_file hold the serialized model bytes
  // themselves (set via CxxConfig::set_model_buffer) rather than paths.
  const bool model_from_memory = config.model_from_memory();
  LOG(INFO) << "load from memory " << model_from_memory;

  Build(model_path,
        model_file,
        param_file,
        prefer_place,
        valid_places,
        passes,
        model_type,
        model_from_memory);
}
void Predictor::Build(const std::string &model_path,
const std::string model_file,
const std::string param_file,
const std::string &model_file,
const std::string &param_file,
const Place &prefer_place,
const std::vector<Place> &valid_places,
const std::vector<std::string> &passes,
lite_api::LiteModelType model_type) {
LOG(INFO) << "Load model from " << model_path;
lite_api::LiteModelType model_type,
bool model_from_memory) {
switch (model_type) {
case lite_api::LiteModelType::kProtobuf: {
bool combined_param = false;
......@@ -102,7 +122,8 @@ void Predictor::Build(const std::string &model_path,
param_file,
scope_.get(),
&program_desc_,
combined_param);
combined_param,
model_from_memory);
} break;
case lite_api::LiteModelType::kNaiveBuffer:
CHECK(!model_path.empty())
......
......@@ -39,14 +39,21 @@ class LITE_API Predictor {
: scope_(root_scope) {}
// Build from a model, with places set for hardware config.
void Build(
const lite_api::CxxConfig& config,
const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {},
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
void Build(
const std::string& model_path,
const std::string model_file_path,
const std::string param_file_path,
const std::string& model_file_path,
const std::string& param_file_path,
const Place& prefer_place,
const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {},
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool memory_from_memory = false);
void Build(const cpp::ProgramDesc& desc,
const Place& prefer_place,
......
......@@ -47,11 +47,7 @@ CxxPaddleApiImpl::CxxPaddleApiImpl() {}
// Initialize the C++ API predictor. The whole CxxConfig is handed to
// Predictor::Build so newly added config options (e.g. loading the model
// from an in-memory buffer) are honored without touching this call site.
// The stale path/file-based Build call from the pre-image is removed —
// keeping both calls would build the program twice.
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
  auto places = config.valid_places();
  // Host/any is always appended as a valid fallback place.
  places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
  raw_predictor_.Build(config, places);
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
......
......@@ -101,17 +101,27 @@ class LITE_API CxxConfig : public ConfigBase {
std::vector<Place> valid_places_;
std::string model_file_;
std::string param_file_;
bool model_from_memory_{false};
public:
void set_preferred_place(const Place& x) { preferred_place_ = x; }
void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
void set_model_file(const std::string& path) { model_file_ = path; }
void set_param_file(const std::string& path) { param_file_ = path; }
// Load the model and parameters from in-memory buffers instead of files.
// The buffers are copied into model_file_/param_file_ (the caller may
// free them afterwards), and model_from_memory_ is set so downstream
// loaders parse these strings directly rather than treating them as
// file paths.
void set_model_buffer(const char* model_buffer,
                      size_t model_buffer_size,
                      const char* param_buffer,
                      size_t param_buffer_size) {
  model_file_ = std::string(model_buffer, model_buffer + model_buffer_size);
  param_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
  model_from_memory_ = true;
}
const Place& preferred_place() const { return preferred_place_; }
const std::vector<Place>& valid_places() const { return valid_places_; }
std::string model_file() const { return model_file_; }
std::string param_file() const { return param_file_; }
bool model_from_memory() const { return model_from_memory_; }
};
/// MobileConfig is the config for the light weight predictor, it will skip
......
......@@ -157,8 +157,16 @@ struct PMNode {
// Assert that the matched node is a statement whose op carries attribute
// `attr_name` with exactly the value `attr`. Returns `this` so that
// assertions can be chained. The stale pre-image delegation to
// assert_op_attr_satisfied (which returned before the new body could
// run) is removed.
template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
  // Capture by value: the predicate may outlive this call's arguments.
  asserts_.push_back([=](const Node* x) {
    if (x && x->IsStmt()) {
      auto* op_info = x->stmt()->op_info();
      bool cond = (op_info->HasAttr(attr_name) &&
                   op_info->GetAttr<T>(attr_name) == attr);
      return cond;
    }
    return false;
  });
  return this;
}
private:
......
......@@ -141,12 +141,16 @@ void ReadBinaryFile(const std::string &filename, std::string *contents) {
}
std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
const std::string &path) {
std::string desc_str;
ReadBinaryFile(path, &desc_str);
const std::string &path, bool program_from_memory) {
std::unique_ptr<framework::proto::ProgramDesc> main_program(
new framework::proto::ProgramDesc);
main_program->ParseFromString(desc_str);
if (!program_from_memory) {
std::string desc_str;
ReadBinaryFile(path, &desc_str);
main_program->ParseFromString(desc_str);
} else {
main_program->ParseFromString(path);
}
return main_program;
}
......@@ -171,7 +175,8 @@ bool IsPersistable(const cpp::VarDesc &var) {
void LoadCombinedParamsPb(const std::string &path,
lite::Scope *scope,
const cpp::ProgramDesc &cpp_prog) {
const cpp::ProgramDesc &cpp_prog,
bool params_from_memory) {
CHECK(scope);
auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
......@@ -186,19 +191,27 @@ void LoadCombinedParamsPb(const std::string &path,
std::sort(paramlist.begin(), paramlist.end());
// Load vars
std::ifstream file(path);
CHECK(file.is_open());
for (size_t i = 0; i < paramlist.size(); ++i) {
auto *var = scope->Var(paramlist[i]);
// Error checking
CHECK(static_cast<bool>(file))
<< "There is a problem with loading model parameters";
LoadLoDTensor(file, var);
}
file.peek();
CHECK(file.eof()) << "You are not allowed to load partial data via"
auto load_var_func = [&](std::istream &is) {
for (size_t i = 0; i < paramlist.size(); ++i) {
auto *var = scope->Var(paramlist[i]);
// Error checking
CHECK(static_cast<bool>(is))
<< "There is a problem with loading model parameters";
LoadLoDTensor(is, var);
}
is.peek();
CHECK(is.eof()) << "You are not allowed to load partial data via"
<< " LoadCombinedParamsPb, use LoadParam instead.";
file.close();
};
if (params_from_memory) {
std::stringstream fin(path, std::ios::in | std::ios::binary);
load_var_func(fin);
} else {
std::ifstream fin(path, std::ios::binary);
CHECK(fin.is_open());
load_var_func(fin);
}
}
void LoadModelPb(const std::string &model_dir,
......@@ -206,26 +219,33 @@ void LoadModelPb(const std::string &model_dir,
const std::string &param_file,
Scope *scope,
cpp::ProgramDesc *cpp_prog,
bool combined) {
bool combined,
bool model_from_memory) {
CHECK(cpp_prog);
CHECK(scope);
cpp_prog->ClearBlocks();
// Load model
VLOG(4) << "Start load model program...";
std::string prog_path = model_dir + "/__model__";
if (combined) {
prog_path = model_file;
}
framework::proto::ProgramDesc pb_proto_prog = *LoadProgram(prog_path);
framework::proto::ProgramDesc pb_proto_prog =
*LoadProgram(prog_path, model_from_memory);
pb::ProgramDesc pb_prog(&pb_proto_prog);
// Transform to cpp::ProgramDesc
TransformProgramDescAnyToCpp(pb_prog, cpp_prog);
// Load Params
// NOTE: Only main block be used now.
VLOG(4) << "Start load model params...";
CHECK(!(!combined && model_from_memory))
<< "If you want use the model_from_memory,"
<< " you should load the combined model using cfg.set_model_buffer "
"interface.";
if (combined) {
LoadCombinedParamsPb(param_file, scope, *cpp_prog);
LoadCombinedParamsPb(param_file, scope, *cpp_prog, model_from_memory);
} else {
auto main_block = pb_proto_prog.blocks(0);
for (auto &var : main_block.vars()) {
......
......@@ -33,7 +33,7 @@ namespace lite {
#ifndef LITE_ON_TINY_PUBLISH
// Read a __model__ file.
// Read a __model__ file. When program_from_memory is true, `path` holds
// the serialized program bytes themselves rather than a file name.
std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
    const std::string& path, bool program_from_memory = false);
// Read a single file containing all the parameters.
void LoadParams(const std::string& path);
......@@ -43,7 +43,8 @@ void LoadParam(const std::string& path, Variable* out);
// Load all combined parameters; with params_from_memory set, `path`
// carries the raw parameter bytes instead of a file path.
void LoadCombinedParamsPb(const std::string& path,
                          lite::Scope* scope,
                          const cpp::ProgramDesc& prog,
                          bool params_from_memory = false);
// Read a model and files of parameters in pb format.
void LoadModelPb(const std::string& model_dir,
......@@ -51,7 +52,8 @@ void LoadModelPb(const std::string& model_dir,
const std::string& param_file,
Scope* scope,
cpp::ProgramDesc* prog,
bool combined = false);
bool combined = false,
bool model_from_memory = false);
// Save a model and files of parameters in pb format.
void SaveModelPb(const std::string& model_dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
Register to post a comment