提交 2109d231 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Add load from memory interface (#1903)

* paddle lite cuda init
can run model with leaky_relu

* add the missing file.
test=develop

* add the load from memory interface.
test=develop

* refine this pr. fix comments
fix ci error
test=develop
上级 ab84a3d5
...@@ -83,14 +83,34 @@ const cpp::ProgramDesc &Predictor::program_desc() const { ...@@ -83,14 +83,34 @@ const cpp::ProgramDesc &Predictor::program_desc() const {
} }
// Accessor for the optimized runtime program (valid only after Build()).
const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
// Convenience overload: unpack a CxxConfig and forward to the
// path/file-based Build() overload. When the config was populated via
// set_model_buffer(), model_file()/param_file() carry the raw bytes and
// model_from_memory() is true.
void Predictor::Build(const lite_api::CxxConfig &config,
                      const std::vector<Place> &valid_places,
                      const std::vector<std::string> &passes,
                      lite_api::LiteModelType model_type) {
  const bool model_from_memory = config.model_from_memory();
  LOG(INFO) << "load from memory " << model_from_memory;
  Build(config.model_dir(),
        config.model_file(),
        config.param_file(),
        config.preferred_place(),
        valid_places,
        passes,
        model_type,
        model_from_memory);
}
void Predictor::Build(const std::string &model_path, void Predictor::Build(const std::string &model_path,
const std::string model_file, const std::string &model_file,
const std::string param_file, const std::string &param_file,
const Place &prefer_place, const Place &prefer_place,
const std::vector<Place> &valid_places, const std::vector<Place> &valid_places,
const std::vector<std::string> &passes, const std::vector<std::string> &passes,
lite_api::LiteModelType model_type) { lite_api::LiteModelType model_type,
LOG(INFO) << "Load model from " << model_path; bool model_from_memory) {
switch (model_type) { switch (model_type) {
case lite_api::LiteModelType::kProtobuf: { case lite_api::LiteModelType::kProtobuf: {
bool combined_param = false; bool combined_param = false;
...@@ -102,7 +122,8 @@ void Predictor::Build(const std::string &model_path, ...@@ -102,7 +122,8 @@ void Predictor::Build(const std::string &model_path,
param_file, param_file,
scope_.get(), scope_.get(),
&program_desc_, &program_desc_,
combined_param); combined_param,
model_from_memory);
} break; } break;
case lite_api::LiteModelType::kNaiveBuffer: case lite_api::LiteModelType::kNaiveBuffer:
CHECK(!model_path.empty()) CHECK(!model_path.empty())
......
...@@ -39,14 +39,21 @@ class LITE_API Predictor { ...@@ -39,14 +39,21 @@ class LITE_API Predictor {
: scope_(root_scope) {} : scope_(root_scope) {}
// Build from a model, with places set for hardware config. // Build from a model, with places set for hardware config.
void Build(
const lite_api::CxxConfig& config,
const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {},
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
void Build( void Build(
const std::string& model_path, const std::string& model_path,
const std::string model_file_path, const std::string& model_file_path,
const std::string param_file_path, const std::string& param_file_path,
const Place& prefer_place, const Place& prefer_place,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}, const std::vector<std::string>& passes = {},
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf); lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool memory_from_memory = false);
void Build(const cpp::ProgramDesc& desc, void Build(const cpp::ProgramDesc& desc,
const Place& prefer_place, const Place& prefer_place,
......
...@@ -47,11 +47,7 @@ CxxPaddleApiImpl::CxxPaddleApiImpl() {} ...@@ -47,11 +47,7 @@ CxxPaddleApiImpl::CxxPaddleApiImpl() {}
// Initialize the predictor from a CxxConfig. Forwards the whole config (so
// in-memory model buffers are honored) instead of unpacking individual paths.
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
  auto places = config.valid_places();
  // Always allow kHost fallback kernels in addition to user-selected places.
  places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
  raw_predictor_.Build(config, places);
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) { std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
......
...@@ -101,17 +101,27 @@ class LITE_API CxxConfig : public ConfigBase { ...@@ -101,17 +101,27 @@ class LITE_API CxxConfig : public ConfigBase {
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
std::string model_file_; std::string model_file_;
std::string param_file_; std::string param_file_;
bool model_from_memory_{false};
public: public:
void set_preferred_place(const Place& x) { preferred_place_ = x; } void set_preferred_place(const Place& x) { preferred_place_ = x; }
void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; } void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
void set_model_file(const std::string& path) { model_file_ = path; } void set_model_file(const std::string& path) { model_file_ = path; }
void set_param_file(const std::string& path) { param_file_ = path; } void set_param_file(const std::string& path) { param_file_ = path; }
void set_model_buffer(const char* model_buffer,
size_t model_buffer_size,
const char* param_buffer,
size_t param_buffer_size) {
model_file_ = std::string(model_buffer, model_buffer + model_buffer_size);
param_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
model_from_memory_ = true;
}
const Place& preferred_place() const { return preferred_place_; } const Place& preferred_place() const { return preferred_place_; }
const std::vector<Place>& valid_places() const { return valid_places_; } const std::vector<Place>& valid_places() const { return valid_places_; }
std::string model_file() const { return model_file_; } std::string model_file() const { return model_file_; }
std::string param_file() const { return param_file_; } std::string param_file() const { return param_file_; }
bool model_from_memory() const { return model_from_memory_; }
}; };
/// MobileConfig is the config for the light weight predictor, it will skip /// MobileConfig is the config for the light weight predictor, it will skip
......
...@@ -157,8 +157,16 @@ struct PMNode { ...@@ -157,8 +157,16 @@ struct PMNode {
// Register a predicate requiring the matched statement node's op to carry
// attribute `attr_name` with exactly the value `attr`.
template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
  // Capture by value: the predicate is stored in asserts_ and may outlive
  // this call frame, so capturing attr_name/attr by reference would dangle.
  asserts_.push_back([=](const Node* x) {
    if (x && x->IsStmt()) {
      auto* op_info = x->stmt()->op_info();
      return op_info->HasAttr(attr_name) &&
             op_info->GetAttr<T>(attr_name) == attr;
    }
    return false;
  });
  return this;
}
private: private:
......
...@@ -141,12 +141,16 @@ void ReadBinaryFile(const std::string &filename, std::string *contents) { ...@@ -141,12 +141,16 @@ void ReadBinaryFile(const std::string &filename, std::string *contents) {
} }
std::unique_ptr<framework::proto::ProgramDesc> LoadProgram( std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
const std::string &path) { const std::string &path, bool program_from_memory) {
std::string desc_str;
ReadBinaryFile(path, &desc_str);
std::unique_ptr<framework::proto::ProgramDesc> main_program( std::unique_ptr<framework::proto::ProgramDesc> main_program(
new framework::proto::ProgramDesc); new framework::proto::ProgramDesc);
main_program->ParseFromString(desc_str); if (!program_from_memory) {
std::string desc_str;
ReadBinaryFile(path, &desc_str);
main_program->ParseFromString(desc_str);
} else {
main_program->ParseFromString(path);
}
return main_program; return main_program;
} }
...@@ -171,7 +175,8 @@ bool IsPersistable(const cpp::VarDesc &var) { ...@@ -171,7 +175,8 @@ bool IsPersistable(const cpp::VarDesc &var) {
void LoadCombinedParamsPb(const std::string &path, void LoadCombinedParamsPb(const std::string &path,
lite::Scope *scope, lite::Scope *scope,
const cpp::ProgramDesc &cpp_prog) { const cpp::ProgramDesc &cpp_prog,
bool params_from_memory) {
CHECK(scope); CHECK(scope);
auto prog = cpp_prog; auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0); auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
...@@ -186,19 +191,27 @@ void LoadCombinedParamsPb(const std::string &path, ...@@ -186,19 +191,27 @@ void LoadCombinedParamsPb(const std::string &path,
std::sort(paramlist.begin(), paramlist.end()); std::sort(paramlist.begin(), paramlist.end());
// Load vars // Load vars
std::ifstream file(path); auto load_var_func = [&](std::istream &is) {
CHECK(file.is_open()); for (size_t i = 0; i < paramlist.size(); ++i) {
for (size_t i = 0; i < paramlist.size(); ++i) { auto *var = scope->Var(paramlist[i]);
auto *var = scope->Var(paramlist[i]); // Error checking
// Error checking CHECK(static_cast<bool>(is))
CHECK(static_cast<bool>(file)) << "There is a problem with loading model parameters";
<< "There is a problem with loading model parameters"; LoadLoDTensor(is, var);
LoadLoDTensor(file, var); }
} is.peek();
file.peek(); CHECK(is.eof()) << "You are not allowed to load partial data via"
CHECK(file.eof()) << "You are not allowed to load partial data via"
<< " LoadCombinedParamsPb, use LoadParam instead."; << " LoadCombinedParamsPb, use LoadParam instead.";
file.close(); };
if (params_from_memory) {
std::stringstream fin(path, std::ios::in | std::ios::binary);
load_var_func(fin);
} else {
std::ifstream fin(path, std::ios::binary);
CHECK(fin.is_open());
load_var_func(fin);
}
} }
void LoadModelPb(const std::string &model_dir, void LoadModelPb(const std::string &model_dir,
...@@ -206,26 +219,33 @@ void LoadModelPb(const std::string &model_dir, ...@@ -206,26 +219,33 @@ void LoadModelPb(const std::string &model_dir,
const std::string &param_file, const std::string &param_file,
Scope *scope, Scope *scope,
cpp::ProgramDesc *cpp_prog, cpp::ProgramDesc *cpp_prog,
bool combined) { bool combined,
bool model_from_memory) {
CHECK(cpp_prog); CHECK(cpp_prog);
CHECK(scope); CHECK(scope);
cpp_prog->ClearBlocks(); cpp_prog->ClearBlocks();
// Load model // Load model
VLOG(4) << "Start load model program...";
std::string prog_path = model_dir + "/__model__"; std::string prog_path = model_dir + "/__model__";
if (combined) { if (combined) {
prog_path = model_file; prog_path = model_file;
} }
framework::proto::ProgramDesc pb_proto_prog = *LoadProgram(prog_path); framework::proto::ProgramDesc pb_proto_prog =
*LoadProgram(prog_path, model_from_memory);
pb::ProgramDesc pb_prog(&pb_proto_prog); pb::ProgramDesc pb_prog(&pb_proto_prog);
// Transform to cpp::ProgramDesc // Transform to cpp::ProgramDesc
TransformProgramDescAnyToCpp(pb_prog, cpp_prog); TransformProgramDescAnyToCpp(pb_prog, cpp_prog);
// Load Params // Load Params
// NOTE: Only main block be used now. // NOTE: Only main block be used now.
VLOG(4) << "Start load model params...";
CHECK(!(!combined && model_from_memory))
<< "If you want use the model_from_memory,"
<< " you should load the combined model using cfg.set_model_buffer "
"interface.";
if (combined) { if (combined) {
LoadCombinedParamsPb(param_file, scope, *cpp_prog); LoadCombinedParamsPb(param_file, scope, *cpp_prog, model_from_memory);
} else { } else {
auto main_block = pb_proto_prog.blocks(0); auto main_block = pb_proto_prog.blocks(0);
for (auto &var : main_block.vars()) { for (auto &var : main_block.vars()) {
......
...@@ -33,7 +33,7 @@ namespace lite { ...@@ -33,7 +33,7 @@ namespace lite {
#ifndef LITE_ON_TINY_PUBLISH
// Read a __model__ file. When `program_from_memory` is true, `path` holds
// the serialized program bytes rather than a filesystem path.
std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
    const std::string& path, bool program_from_memory = false);

// Read a single file containing all the parameters.
void LoadParams(const std::string& path);
...@@ -43,7 +43,8 @@ void LoadParam(const std::string& path, Variable* out); ...@@ -43,7 +43,8 @@ void LoadParam(const std::string& path, Variable* out);
// Load all persistable parameters from one combined file, or from an
// in-memory buffer carried in `path` when `params_from_memory` is true.
void LoadCombinedParamsPb(const std::string& path,
                          lite::Scope* scope,
                          const cpp::ProgramDesc& prog,
                          bool params_from_memory = false);
// Read a model and files of parameters in pb format. // Read a model and files of parameters in pb format.
void LoadModelPb(const std::string& model_dir, void LoadModelPb(const std::string& model_dir,
...@@ -51,7 +52,8 @@ void LoadModelPb(const std::string& model_dir, ...@@ -51,7 +52,8 @@ void LoadModelPb(const std::string& model_dir,
const std::string& param_file, const std::string& param_file,
Scope* scope, Scope* scope,
cpp::ProgramDesc* prog, cpp::ProgramDesc* prog,
bool combined = false); bool combined = false,
bool model_from_memory = false);
// Save a model and files of parameters in pb format. // Save a model and files of parameters in pb format.
void SaveModelPb(const std::string& model_dir, void SaveModelPb(const std::string& model_dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册