diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index 8505587a0f819e62a0ba1c5cc969b06ca30499b8..98b79e58aa349436ad64dcc0a54256d5d9ead3df 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -18,18 +18,26 @@ namespace paddle { namespace lite { void LightPredictor::Build(const std::string& model_dir, - lite_api::LiteModelType model_type) { + const std::string& model_buffer, + const std::string& param_buffer, + lite_api::LiteModelType model_type, + bool model_from_memory) { cpp::ProgramDesc desc; - LOG(INFO) << "Load model from " << model_dir; switch (model_type) { #ifndef LITE_ON_TINY_PUBLISH case lite_api::LiteModelType::kProtobuf: LoadModelPb(model_dir, "", "", scope_.get(), &desc); break; #endif - case lite_api::LiteModelType::kNaiveBuffer: - LoadModelNaive(model_dir, scope_.get(), &desc); + case lite_api::LiteModelType::kNaiveBuffer: { + if (model_from_memory) { + LoadModelNaiveFromMemory( + model_buffer, param_buffer, scope_.get(), &desc); + } else { + LoadModelNaive(model_dir, scope_.get(), &desc); + } break; + } default: LOG(FATAL) << "Unknown model type"; } @@ -83,11 +91,5 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { program_->set_exec_scope(program.exec_scope()); } -LightPredictor::LightPredictor(const std::string& model_dir, - lite_api::LiteModelType model_type) { - scope_ = std::make_shared<Scope>(); - Build(model_dir, model_type); - } - } // namespace lite } // namespace paddle diff --git a/lite/api/light_api.h b/lite/api/light_api.h index 6d3a0bcebbce299c0d130193baa11ecffafb733b..241540174489d40c0688cba2ce7911f11b5b5832 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -38,9 +38,15 @@ namespace lite { */ class LITE_API LightPredictor { public: - explicit LightPredictor( + LightPredictor( const std::string& model_dir, - lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf); + const std::string& model_buffer = "", + const std::string& param_buffer = "", + bool model_from_memory 
= false, + lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf) { + scope_ = std::make_shared<Scope>(); + Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory); + } void Run() { program_->Run(); } @@ -58,7 +64,11 @@ class LITE_API LightPredictor { private: void Build( const std::string& model_dir, - lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf); + const std::string& model_buffer, + const std::string& param_buffer, + lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, + bool model_from_memory = false); + void BuildRuntimeProgram(const cpp::ProgramDesc& prog); private: diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index fd42ade4c7579b9c204fb41926636182727ec4dc..6075f1a36f6803b7e5090697802dcb47fafa0d0d 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -45,6 +45,9 @@ void LightPredictorImpl::Init(const MobileConfig& config) { lite::DeviceInfo::Global().SetRunMode(config.power_mode(), config.threads()); #endif raw_predictor_.reset(new lite::LightPredictor(config.model_dir(), + config.model_buffer(), + config.param_buffer(), + config.model_from_memory(), LiteModelType::kNaiveBuffer)); } diff --git a/lite/api/light_api_test.cc b/lite/api/light_api_test.cc index 6f565b518b594fd5090ff60690549a1ad962f1d6..8e2fc420bc3be91e35047b823e628b80f2175496 100644 --- a/lite/api/light_api_test.cc +++ b/lite/api/light_api_test.cc @@ -28,7 +28,46 @@ TEST(LightAPI, load) { if (FLAGS_optimized_model.empty()) { FLAGS_optimized_model = "lite_naive_model"; } - LightPredictor predictor(FLAGS_optimized_model); + LightPredictor predictor(FLAGS_optimized_model, "", ""); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector<int64_t>({100, 100}))); auto* data = input_tensor->mutable_data<float>(); for (int i = 0; i < 100 * 100; i++) { data[i] = i; } + predictor.Run(); + const auto* output = predictor.GetOutput(0); + const float* raw_output = 
output->data<float>(); + + for (int i = 0; i < 10; i++) { + LOG(INFO) << "out " << raw_output[i]; + } +} + +TEST(LightAPI, loadNaiveBuffer) { + if (FLAGS_optimized_model.empty()) { + FLAGS_optimized_model = "lite_naive_model"; + } + + auto model_path = std::string(FLAGS_optimized_model) + "/__model__.nb"; + auto params_path = std::string(FLAGS_optimized_model) + "/param.nb"; + std::string model_buffer = lite::ReadFile(model_path); + size_t size_model = model_buffer.length(); + std::string params_buffer = lite::ReadFile(params_path); + size_t size_params = params_buffer.length(); + LOG(INFO) << "sizeModel: " << size_model; + LOG(INFO) << "sizeParams: " << size_params; + + lite_api::MobileConfig config; + config.set_model_buffer( + model_buffer.c_str(), size_model, params_buffer.c_str(), size_params); + LightPredictor predictor(config.model_dir(), + config.model_buffer(), + config.param_buffer(), + config.model_from_memory(), + lite_api::LiteModelType::kNaiveBuffer); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector<int64_t>({100, 100}))); diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 237122474b1b594d031f8e51ee7c9af6f00bcc21..b1a8b21935bfbab603c7f27e233cc6115414dc7e 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -129,6 +129,9 @@ class LITE_API CxxConfig : public ConfigBase { class LITE_API MobileConfig : public ConfigBase { PowerMode mode_{LITE_POWER_HIGH}; int threads_{1}; + std::string model_buffer_; + std::string param_buffer_; + bool model_from_memory_{false}; public: MobileConfig(Place preferred_place = Place(TARGET(kARM), @@ -139,9 +142,20 @@ class LITE_API MobileConfig : public ConfigBase { : mode_(mode), threads_(threads) {} void set_power_mode(PowerMode mode) { mode_ = mode; } void set_threads(int threads) { threads_ = threads; } + void set_model_buffer(const char* model_buffer, + size_t model_buffer_size, + const char* param_buffer, + size_t param_buffer_size) { + model_buffer_ = 
std::string(model_buffer, model_buffer + model_buffer_size); + param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size); + model_from_memory_ = true; + } PowerMode power_mode() const { return mode_; } int threads() const { return threads_; } + bool model_from_memory() const { return model_from_memory_; } + const std::string& model_buffer() const { return model_buffer_; } + const std::string& param_buffer() const { return param_buffer_; } }; template <typename ConfigT> diff --git a/lite/api/paddle_api_test.cc b/lite/api/paddle_api_test.cc index cc1523f185b85bae617492b5671aedb07b0ae979..02502ff9c80f3ee3c5a23f8ef6909353d839ea9e 100644 --- a/lite/api/paddle_api_test.cc +++ b/lite/api/paddle_api_test.cc @@ -19,7 +19,7 @@ #include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_passes.h" #include "lite/utils/cp_logging.h" - +#include "lite/utils/io.h" DEFINE_string(model_dir, "", ""); namespace paddle { @@ -58,6 +58,7 @@ TEST(CxxApi, run) { LiteModelType::kNaiveBuffer); } +// Demo1 for Mobile Devices :Load model from file and run #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK TEST(LightApi, run) { lite_api::MobileConfig config; @@ -82,6 +83,39 @@ EXPECT_NEAR(out[0], 50.2132, 1e-3); EXPECT_NEAR(out[1], -28.8729, 1e-3); } + +// Demo2 for Loading model from memory +TEST(MobileConfig, LoadfromMemory) { + // Get naive buffer + auto model_path = std::string(FLAGS_model_dir) + ".opt2.naive/__model__.nb"; + auto params_path = std::string(FLAGS_model_dir) + ".opt2.naive/param.nb"; + std::string model_buffer = lite::ReadFile(model_path); + size_t size_model = model_buffer.length(); + std::string params_buffer = lite::ReadFile(params_path); + size_t size_params = params_buffer.length(); + // set model buffer and run model + lite_api::MobileConfig config; + config.set_model_buffer( + model_buffer.c_str(), size_model, params_buffer.c_str(), size_params); + + auto predictor = lite_api::CreatePaddlePredictor(config); + auto input_tensor = 
predictor->GetInput(0); + input_tensor->Resize(std::vector<int64_t>({100, 100})); + auto* data = input_tensor->mutable_data<float>(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor->Run(); + + const auto output = predictor->GetOutput(0); + const float* raw_output = output->data<float>(); + + for (int i = 0; i < 10; i++) { + LOG(INFO) << "out " << raw_output[i]; + } +} + #endif } // namespace lite_api diff --git a/lite/model_parser/model_parser.cc b/lite/model_parser/model_parser.cc index 95da6d450602f5db8069b3421f25b4c080475209..cc7c22ea6f87d510041320cc470ef1c9040327ba 100644 --- a/lite/model_parser/model_parser.cc +++ b/lite/model_parser/model_parser.cc @@ -661,9 +661,14 @@ void LoadParamNaive(const std::string &path, void LoadCombinedParamsNaive(const std::string &path, lite::Scope *scope, - const cpp::ProgramDesc &cpp_prog) { + const cpp::ProgramDesc &cpp_prog, + bool params_from_memory) { naive_buffer::BinaryTable table; - table.LoadFromFile(path); + if (params_from_memory) { + table.LoadFromMemory(path.c_str(), path.length()); + } else { + table.LoadFromFile(path); + } naive_buffer::proto::CombinedParamsDesc pt_desc(&table); pt_desc.Load(); naive_buffer::CombinedParamsDesc desc(&pt_desc); @@ -710,7 +715,7 @@ void LoadModelNaive(const std::string &model_dir, // NOTE: Only main block be used now. 
if (combined) { const std::string combined_params_path = model_dir + "/param.nb"; - LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog); + LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, false); } else { auto &prog = *cpp_prog; auto &main_block_desc = *prog.GetBlock(0); @@ -750,5 +755,40 @@ void LoadModelNaive(const std::string &model_dir, VLOG(4) << "Load naive buffer model in '" << model_dir << "' successfully"; } +void LoadModelNaiveFromMemory(const std::string &model_buffer, + const std::string ¶m_buffer, + Scope *scope, + cpp::ProgramDesc *cpp_prog) { + CHECK(cpp_prog); + CHECK(scope); + cpp_prog->ClearBlocks(); + + // Load model + + std::string prog_path = model_buffer; + + naive_buffer::BinaryTable table; + table.LoadFromMemory(prog_path.c_str(), prog_path.length()); + + naive_buffer::proto::ProgramDesc nb_proto_prog(&table); + nb_proto_prog.Load(); + naive_buffer::ProgramDesc nb_prog(&nb_proto_prog); + + // Transform to cpp::ProgramDesc + TransformProgramDescAnyToCpp(nb_prog, cpp_prog); + + // Load Params + // NOTE: Only main block be used now. 
+ // only combined Params are supported in Loading Model from memory + std::string combined_params_path = param_buffer; + LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, true); + +#ifdef LITE_WITH_NPU + LOG(FATAL) << "load from memory is not supported by NPU"; +#endif + + VLOG(4) << "Load model from naive buffer memory successfully"; +} + } // namespace lite } // namespace paddle diff --git a/lite/model_parser/model_parser.h b/lite/model_parser/model_parser.h index 5592ff9cf70c8836b6a24945215aabe4a98efcce..81be2579e3932d7165480afd5bb89f567155cf36 100644 --- a/lite/model_parser/model_parser.h +++ b/lite/model_parser/model_parser.h @@ -94,14 +94,15 @@ void LoadParamNaive(const std::string& path, lite::Scope* scope, const std::string& name); -void LoadCombinedParamsNaive(const std::string& path, - lite::Scope* scope, - const cpp::ProgramDesc& cpp_prog); - void LoadModelNaive(const std::string& model_dir, lite::Scope* scope, cpp::ProgramDesc* prog, bool combined = true); +void LoadModelNaiveFromMemory(const std::string& model_buffer, + const std::string& param_buffer, + lite::Scope* scope, + cpp::ProgramDesc* cpp_prog); + } // namespace lite } // namespace paddle diff --git a/lite/model_parser/model_parser_test.cc b/lite/model_parser/model_parser_test.cc index cca70142450816e6db5f5ae463fc1b5fcc3ef36d..58083027849cc007bce80bd10004d0a13259fda7 100644 --- a/lite/model_parser/model_parser_test.cc +++ b/lite/model_parser/model_parser_test.cc @@ -121,12 +121,17 @@ TEST(ModelParser, SaveModelNaive) { SaveModelNaive(save_pb_model_path, scope, prog); } -TEST(ModelParser, LoadModelNaive) { +TEST(ModelParser, LoadModelNaiveFromMemory) { CHECK(!FLAGS_model_dir.empty()); cpp::ProgramDesc prog; Scope scope; - const std::string model_path = FLAGS_model_dir + ".saved.naive"; - LoadModelNaive(model_path, &scope, &prog); + + auto model_path = std::string(FLAGS_model_dir) + ".saved.naive/__model__.nb"; + auto params_path = std::string(FLAGS_model_dir) + 
".saved.naive/param.nb"; + std::string model_buffer = lite::ReadFile(model_path); + std::string params_buffer = lite::ReadFile(params_path); + + LoadModelNaiveFromMemory(model_buffer, params_buffer, &scope, &prog); } } // namespace lite diff --git a/lite/model_parser/naive_buffer/naive_buffer.cc b/lite/model_parser/naive_buffer/naive_buffer.cc index 02860630a5e7d5852b85d38bccd0aa415ed4bb7e..cefaf0c28a34a70c095362e9972c9ef99d5fa80c 100644 --- a/lite/model_parser/naive_buffer/naive_buffer.cc +++ b/lite/model_parser/naive_buffer/naive_buffer.cc @@ -66,6 +66,14 @@ void BinaryTable::LoadFromFile(const std::string &filename) { is_mutable_mode_ = false; } +void BinaryTable::LoadFromMemory(const char *buffer, size_t buffer_size) { + // get buffer + bytes_.resize(buffer_size); + memcpy(reinterpret_cast<char *>(&bytes_[0]), buffer, buffer_size); + // Set readonly. + is_mutable_mode_ = false; +} + void StringBuilder::Save() { // memory format: [size][string data] uint64_t mem_size = sizeof(uint64_t) + data_.size(); diff --git a/lite/model_parser/naive_buffer/naive_buffer.h b/lite/model_parser/naive_buffer/naive_buffer.h index 4877b5ccd94f8725b0818f2015974b44548023a6..e2e2f7fb1ea3cb5b226bf09bd16074f51e171c75 100644 --- a/lite/model_parser/naive_buffer/naive_buffer.h +++ b/lite/model_parser/naive_buffer/naive_buffer.h @@ -63,6 +63,7 @@ struct BinaryTable { void SaveToFile(const std::string& filename) const; void LoadFromFile(const std::string& filename); + void LoadFromMemory(const char* buffer, size_t buffer_size); }; /* diff --git a/lite/utils/io.h b/lite/utils/io.h index ddd7e39b0d1d3e0b425cff1b31641a6a145f7bfa..98a0f39b084c1ec0767299501f6f359dab2017b3 100644 --- a/lite/utils/io.h +++ b/lite/utils/io.h @@ -43,5 +43,14 @@ static void MkDirRecur(const std::string& path) { #endif } +// read buffer from file +static std::string ReadFile(const std::string& filename) { + std::ifstream ifile(filename.c_str()); + std::ostringstream buf; + char ch; + while (buf && ifile.get(ch)) buf.put(ch); 
+ return buf.str(); +} + } // namespace lite } // namespace paddle