diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 5f160b6f793392c43c018e2d613aedc1a7536d4d..ccf36e31912196a5c5fb979980ccf966f2df4be9 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -83,14 +83,34 @@ const cpp::ProgramDesc &Predictor::program_desc() const {
 }
 const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
 
+void Predictor::Build(const lite_api::CxxConfig &config,
+                      const std::vector<Place> &valid_places,
+                      const std::vector<std::string> &passes,
+                      lite_api::LiteModelType model_type) {
+  const std::string &model_path = config.model_dir();
+  const std::string &model_file = config.model_file();
+  const std::string &param_file = config.param_file();
+  const Place prefer_place = config.preferred_place();
+  const bool model_from_memory = config.model_from_memory();
+  LOG(INFO) << "load from memory " << model_from_memory;
+
+  Build(model_path,
+        model_file,
+        param_file,
+        prefer_place,
+        valid_places,
+        passes,
+        model_type,
+        model_from_memory);
+}
 void Predictor::Build(const std::string &model_path,
-                      const std::string model_file,
-                      const std::string param_file,
+                      const std::string &model_file,
+                      const std::string &param_file,
                       const Place &prefer_place,
                       const std::vector<Place> &valid_places,
                       const std::vector<std::string> &passes,
-                      lite_api::LiteModelType model_type) {
-  LOG(INFO) << "Load model from " << model_path;
+                      lite_api::LiteModelType model_type,
+                      bool model_from_memory) {
   switch (model_type) {
     case lite_api::LiteModelType::kProtobuf: {
       bool combined_param = false;
@@ -102,7 +122,8 @@ void Predictor::Build(const std::string &model_path,
                   param_file,
                   scope_.get(),
                   &program_desc_,
-                  combined_param);
+                  combined_param,
+                  model_from_memory);
     } break;
     case lite_api::LiteModelType::kNaiveBuffer:
       CHECK(!model_path.empty())
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index c12f6996f42c3c134a3bd86c03bac90c558e97e8..2506ae47b0ddbce683d8f4b12e000bb3ea19d497 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -39,14 +39,21 @@ class LITE_API Predictor {
       : scope_(root_scope) {}
 
   // Build from a model, with places set for hardware config.
+  void Build(
+      const lite_api::CxxConfig& config,
+      const std::vector<Place>& valid_places,
+      const std::vector<std::string>& passes = {},
+      lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
+
   void Build(
       const std::string& model_path,
-      const std::string model_file_path,
-      const std::string param_file_path,
+      const std::string& model_file_path,
+      const std::string& param_file_path,
       const Place& prefer_place,
       const std::vector<Place>& valid_places,
       const std::vector<std::string>& passes = {},
-      lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
+      lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
+      bool model_from_memory = false);
 
   void Build(const cpp::ProgramDesc& desc,
              const Place& prefer_place,
diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index 3a1a52e2afda6aeb334056a5c6fa3f34a76d7737..b8c92a8f96afefa7a2de6b844980f9c0f769f6a9 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -47,11 +47,7 @@ CxxPaddleApiImpl::CxxPaddleApiImpl() {}
 void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
   auto places = config.valid_places();
   places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
-  raw_predictor_.Build(config.model_dir(),
-                       config.model_file(),
-                       config.param_file(),
-                       config.preferred_place(),
-                       places);
+  raw_predictor_.Build(config, places);
 }
 
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index 4876aa7fb1fc5e187ec42022e65ce0dcccd2fac7..237122474b1b594d031f8e51ee7c9af6f00bcc21 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -101,17 +101,27 @@ class LITE_API CxxConfig : public ConfigBase {
   std::vector<Place> valid_places_;
   std::string model_file_;
   std::string param_file_;
+  bool model_from_memory_{false};
 
  public:
   void set_preferred_place(const Place& x) { preferred_place_ = x; }
   void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
   void set_model_file(const std::string& path) { model_file_ = path; }
   void set_param_file(const std::string& path) { param_file_ = path; }
+  void set_model_buffer(const char* model_buffer,
+                        size_t model_buffer_size,
+                        const char* param_buffer,
+                        size_t param_buffer_size) {
+    model_file_ = std::string(model_buffer, model_buffer + model_buffer_size);
+    param_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
+    model_from_memory_ = true;
+  }
 
   const Place& preferred_place() const { return preferred_place_; }
   const std::vector<Place>& valid_places() const { return valid_places_; }
   std::string model_file() const { return model_file_; }
   std::string param_file() const { return param_file_; }
+  bool model_from_memory() const { return model_from_memory_; }
 };
 
 /// MobileConfig is the config for the light weight predictor, it will skip
diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h
index 7bcb5c07a4bcd65ae9190f30be0c05de31a92e41..42d1b3fe555a97ba4168e205217867e35c4b0894 100644
--- a/lite/core/mir/pattern_matcher.h
+++ b/lite/core/mir/pattern_matcher.h
@@ -157,8 +157,16 @@ struct PMNode {
 
   template <typename T>
   PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
-    return assert_op_attr_satisfied<T>(
-        attr_name, [&](const T& src) { return src == attr; });
+    asserts_.push_back([=](const Node* x) {
+      if (x && x->IsStmt()) {
+        auto* op_info = x->stmt()->op_info();
+        bool cond = (op_info->HasAttr(attr_name) &&
+                     op_info->GetAttr<T>(attr_name) == attr);
+        return cond;
+      }
+      return false;
+    });
+    return this;
   }
 
  private:
diff --git a/lite/model_parser/model_parser.cc b/lite/model_parser/model_parser.cc
index ed8f5a96f0d4a3d66d1e83d80081c57c4ad95288..398f3b3d0a9c0ea65d80101cabb0036ad722701b 100644
--- a/lite/model_parser/model_parser.cc
+++ b/lite/model_parser/model_parser.cc
@@ -141,12 +141,16 @@ void ReadBinaryFile(const std::string &filename, std::string *contents) {
 }
 
 std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
-    const std::string &path) {
-  std::string desc_str;
-  ReadBinaryFile(path, &desc_str);
+    const std::string &path, bool program_from_memory) {
   std::unique_ptr<framework::proto::ProgramDesc> main_program(
       new framework::proto::ProgramDesc);
-  main_program->ParseFromString(desc_str);
+  if (!program_from_memory) {
+    std::string desc_str;
+    ReadBinaryFile(path, &desc_str);
+    main_program->ParseFromString(desc_str);
+  } else {
+    main_program->ParseFromString(path);
+  }
   return main_program;
 }
 
@@ -171,7 +175,8 @@ bool IsPersistable(const cpp::VarDesc &var) {
 
 void LoadCombinedParamsPb(const std::string &path,
                           lite::Scope *scope,
-                          const cpp::ProgramDesc &cpp_prog) {
+                          const cpp::ProgramDesc &cpp_prog,
+                          bool params_from_memory) {
   CHECK(scope);
   auto prog = cpp_prog;
   auto &main_block_desc = *prog.GetBlock(0);
@@ -186,19 +191,27 @@ void LoadCombinedParamsPb(const std::string &path,
   std::sort(paramlist.begin(), paramlist.end());
 
   // Load vars
-  std::ifstream file(path);
-  CHECK(file.is_open());
-  for (size_t i = 0; i < paramlist.size(); ++i) {
-    auto *var = scope->Var(paramlist[i]);
-    // Error checking
-    CHECK(static_cast<bool>(file))
-        << "There is a problem with loading model parameters";
-    LoadLoDTensor(file, var);
-  }
-  file.peek();
-  CHECK(file.eof()) << "You are not allowed to load partial data via"
+  auto load_var_func = [&](std::istream &is) {
+    for (size_t i = 0; i < paramlist.size(); ++i) {
+      auto *var = scope->Var(paramlist[i]);
+      // Error checking
+      CHECK(static_cast<bool>(is))
+          << "There is a problem with loading model parameters";
+      LoadLoDTensor(is, var);
+    }
+    is.peek();
+    CHECK(is.eof()) << "You are not allowed to load partial data via"
                     << " LoadCombinedParamsPb, use LoadParam instead.";
-  file.close();
+  };
+
+  if (params_from_memory) {
+    std::stringstream fin(path, std::ios::in | std::ios::binary);
+    load_var_func(fin);
+  } else {
+    std::ifstream fin(path, std::ios::binary);
+    CHECK(fin.is_open());
+    load_var_func(fin);
+  }
 }
 
 void LoadModelPb(const std::string &model_dir,
@@ -206,26 +219,33 @@ void LoadModelPb(const std::string &model_dir,
                  const std::string &model_file,
                  const std::string &param_file,
                  Scope *scope,
                  cpp::ProgramDesc *cpp_prog,
-                 bool combined) {
+                 bool combined,
+                 bool model_from_memory) {
   CHECK(cpp_prog);
   CHECK(scope);
   cpp_prog->ClearBlocks();
 
   // Load model
+  VLOG(4) << "Start load model program...";
   std::string prog_path = model_dir + "/__model__";
   if (combined) {
     prog_path = model_file;
   }
-  framework::proto::ProgramDesc pb_proto_prog = *LoadProgram(prog_path);
+  framework::proto::ProgramDesc pb_proto_prog =
+      *LoadProgram(prog_path, model_from_memory);
   pb::ProgramDesc pb_prog(&pb_proto_prog);
-  // Transform to cpp::ProgramDesc
   TransformProgramDescAnyToCpp(pb_prog, cpp_prog);
 
   // Load Params
   // NOTE: Only main block be used now.
+  VLOG(4) << "Start load model params...";
+  CHECK(!(!combined && model_from_memory))
+      << "If you want to use model_from_memory,"
+      << " you should load the combined model using the cfg.set_model_buffer "
+         "interface.";
   if (combined) {
-    LoadCombinedParamsPb(param_file, scope, *cpp_prog);
+    LoadCombinedParamsPb(param_file, scope, *cpp_prog, model_from_memory);
   } else {
     auto main_block = pb_proto_prog.blocks(0);
     for (auto &var : main_block.vars()) {
diff --git a/lite/model_parser/model_parser.h b/lite/model_parser/model_parser.h
index 74709c044d32c672ca8c643e0350804eb9a7a2a5..5592ff9cf70c8836b6a24945215aabe4a98efcce 100644
--- a/lite/model_parser/model_parser.h
+++ b/lite/model_parser/model_parser.h
@@ -33,7 +33,7 @@ namespace lite {
 #ifndef LITE_ON_TINY_PUBLISH
 // Read a __model__ file.
 std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
-    const std::string& path);
+    const std::string& path, bool program_from_memory = false);
 
 // Read a single file containing all the parameters.
 void LoadParams(const std::string& path);
@@ -43,7 +43,8 @@ void LoadParam(const std::string& path, Variable* out);
 
 void LoadCombinedParamsPb(const std::string& path,
                           lite::Scope* scope,
-                          const cpp::ProgramDesc& prog);
+                          const cpp::ProgramDesc& prog,
+                          bool params_from_memory = false);
 
 // Read a model and files of parameters in pb format.
 void LoadModelPb(const std::string& model_dir,
@@ -51,7 +52,8 @@ void LoadModelPb(const std::string& model_dir,
                  const std::string& param_file,
                  Scope* scope,
                  cpp::ProgramDesc* prog,
-                 bool combined = false);
+                 bool combined = false,
+                 bool model_from_memory = false);
 
 // Save a model and files of parameters in pb format.
 void SaveModelPb(const std::string& model_dir,
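
Usage sketch (not part of the patch): with this change a combined model can be handed to the CXX API from memory via CxxConfig::set_model_buffer() instead of being read from disk by the predictor. The snippet below is illustrative only; the ReadWholeFile helper, the "model"/"params" file names, the include path, and the ARM place are assumptions, while set_model_buffer, set_valid_places, set_preferred_place, the TARGET/PRECISION macros, and CreatePaddlePredictor come from the existing lite_api surface touched by this diff.

// Illustrative only: names marked "assumed" are not defined by this patch.
#include <fstream>
#include <sstream>
#include <string>

#include "paddle_api.h"  // lite/api/paddle_api.h; include path is assumed

// Read a whole binary file into a string buffer (assumed helper).
static std::string ReadWholeFile(const std::string &path) {
  std::ifstream file(path, std::ios::binary);
  std::stringstream buffer;
  buffer << file.rdbuf();
  return buffer.str();
}

int main() {
  // The memory path requires a combined model: one model file plus one
  // params file (see the CHECK in LoadModelPb). File names are assumed.
  std::string model_buffer = ReadWholeFile("model");
  std::string param_buffer = ReadWholeFile("params");

  paddle::lite_api::CxxConfig config;
  config.set_model_buffer(model_buffer.c_str(),
                          model_buffer.size(),
                          param_buffer.c_str(),
                          param_buffer.size());
  config.set_preferred_place(
      paddle::lite_api::Place{TARGET(kARM), PRECISION(kFloat)});
  config.set_valid_places(
      {paddle::lite_api::Place{TARGET(kARM), PRECISION(kFloat)}});

  auto predictor =
      paddle::lite_api::CreatePaddlePredictor<paddle::lite_api::CxxConfig>(
          config);
  return predictor != nullptr ? 0 : 1;
}

Because set_model_buffer() stores the buffers in model_file_/param_file_ and sets model_from_memory_, the same CxxConfig fields carry either file paths or raw protobuf bytes; model_from_memory() is what tells LoadProgram/LoadCombinedParamsPb to parse the strings directly instead of treating them as paths.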