diff --git a/src/framework/loader.cpp b/src/framework/loader.cpp
index a434314730eb40b7e4017050a84a7d9742934396..eb07e9f6155370880f6fb8e302a8e396df17954d 100644
--- a/src/framework/loader.cpp
+++ b/src/framework/loader.cpp
@@ -82,6 +82,54 @@ void Loader::InitMemoryFromProgram(
     }
   }
 }
+template <>
+const Program<GPU_CL, Precision::FP32>
+Loader<GPU_CL, Precision::FP32>::LoadCombinedMemory(
+    size_t read_size, const uint8_t *buf, size_t combined_params_len,
+    uint8_t *combined_params_buf, bool optimize, bool quantification) {
+  bool can_add_split = false;
+
+  PaddleMobile__Framework__Proto__ProgramDesc *c_program;
+  PADDLE_MOBILE_ENFORCE(buf != nullptr, "read from __model__ is null");
+
+  c_program = paddle_mobile__framework__proto__program_desc__unpack(
+      nullptr, read_size, buf);
+  //
+  PADDLE_MOBILE_ENFORCE(c_program != nullptr, "program is null");
+  //
+  DLOG << "n_ops: " << (*c_program->blocks)->n_ops;
+  //
+
+  auto originProgramDesc = std::make_shared<ProgramDesc>(c_program);
+
+  Program<GPU_CL, Precision::FP32> program;
+  program.combined = true;
+  program.originProgram = originProgramDesc;
+  program.quantification = quantification;
+  program.combined_params_len = combined_params_len;
+  program.combined_params_buf = combined_params_buf;
+
+  auto scope = std::make_shared<Scope>();
+  program.scope = scope;
+  InitMemoryFromProgram(originProgramDesc, scope);
+  if (optimize) {
+    ProgramOptimize program_optimize;
+    program.optimizeProgram =
+        program_optimize.FusionOptimize(originProgramDesc, can_add_split);
+    if (!program.optimizeProgram) {
+      program.optimizeProgram = originProgramDesc;
+    }
+  }
+  if (optimize) {
+    program.optimizeProgram->Description("optimize: ");
+  } else {
+    originProgramDesc->Description("program: ");
+  }
+  paddle_mobile__framework__proto__program_desc__free_unpacked(c_program,
+                                                               nullptr);
+  return program;
+}
+
 #endif
 
 /**
diff --git a/test/framework/test_load_memory_inference_api.cpp b/test/framework/test_load_memory_inference_api.cpp
index 05d51910172547c6dab7adc8231663be55c916bf..5b2773f8f1a21c3b9253b34fc5c18cd64ece27e7 100644
--- a/test/framework/test_load_memory_inference_api.cpp
+++ b/test/framework/test_load_memory_inference_api.cpp
@@ -55,11 +55,11 @@ static char *Get_binary_data(std::string filename) {
 paddle_mobile::PaddleMobileConfig GetConfig() {
   paddle_mobile::PaddleMobileConfig config;
   config.precision = paddle_mobile::PaddleMobileConfig::FP32;
-  config.device = paddle_mobile::PaddleMobileConfig::kCPU;
+  config.device = paddle_mobile::PaddleMobileConfig::kGPU_CL;
   const std::shared_ptr<paddle_mobile::PaddleModelMemoryPack> &memory_pack =
       std::make_shared<paddle_mobile::PaddleModelMemoryPack>();
-  auto model_path = std::string(g_genet_combine) + "/model";
-  auto params_path = std::string(g_genet_combine) + "/params";
+  auto model_path = std::string(g_mobilenet_combined) + "/model";
+  auto params_path = std::string(g_mobilenet_combined) + "/params";
   memory_pack->model_size =
       ReadBuffer(model_path.c_str(), &memory_pack->model_buf);
   std::cout << "sizeBuf: " << memory_pack->model_size << std::endl;