From d4beefa185692b5f3c6df150d549638bbcb7e8fb Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Sat, 1 Sep 2018 00:34:10 +0800 Subject: [PATCH] reflect load logics and add memoryload api. closed #876 --- src/framework/program/program.h | 2 + src/io/executor.cpp | 13 ++- src/io/loader.cpp | 130 ++++++++++++++++++++-------- src/io/loader.h | 5 ++ src/io/paddle_mobile.cpp | 26 ++++++ src/io/paddle_mobile.h | 13 +++ test/CMakeLists.txt | 4 + test/framework/test_load_memory.cpp | 67 ++++++++++++++ 8 files changed, 224 insertions(+), 36 deletions(-) create mode 100644 test/framework/test_load_memory.cpp diff --git a/src/framework/program/program.h b/src/framework/program/program.h index e500d50034..192328a567 100644 --- a/src/framework/program/program.h +++ b/src/framework/program/program.h @@ -31,6 +31,8 @@ class Program { std::string para_path; bool combined = false; bool quantification = false; + size_t combined_params_len; + const uint8_t *combined_params_buf; private: }; diff --git a/src/io/executor.cpp b/src/io/executor.cpp index 9100528705..8ef199c4ea 100644 --- a/src/io/executor.cpp +++ b/src/io/executor.cpp @@ -63,6 +63,8 @@ Executor::Executor(const framework::Program p, int batch_size, } Variable *variable_ptr = program_.scope->Var("batch_size"); variable_ptr[0].SetValue(batch_size); + PADDLE_MOBILE_ENFORCE(to_predict_program_ != nullptr, + "to_predict_program_ == NULL!"); const std::vector> blocks = to_predict_program_->Blocks(); #ifdef PADDLE_EXECUTOR_MULTITHREAD @@ -234,8 +236,15 @@ void Executor::InitMemory() { template void Executor::InitCombineMemory() { - LOG(kLOG_INFO) << " begin init combine memory"; - char *origin_data = Get_binary_data(program_.para_path); + char *origin_data; + if (program_.combined_params_buf && program_.combined_params_len) { + LOG(kLOG_INFO) << "use outter memory"; + origin_data = (char *)program_.combined_params_buf; + } else { + LOG(kLOG_INFO) << " begin init combine memory"; + origin_data = Get_binary_data(program_.para_path); + } + PADDLE_MOBILE_ENFORCE(origin_data != nullptr, "origin_data==nullptr!!!"); char *data = origin_data; for (const auto &block : to_predict_program_->Blocks()) { for (const auto &var_desc : block->Vars()) { diff --git a/src/io/loader.cpp b/src/io/loader.cpp index cdcecf02ab..7a0912106d 100644 --- a/src/io/loader.cpp +++ b/src/io/loader.cpp @@ -20,6 +20,62 @@ limitations under the License. */ namespace paddle_mobile { using framework::Variable; +/** + * muteandresize tensor as originProgramDesc and scope in loadParams + * + * @param originProgramDesc + * @param scope + */ +void InitMemoryFromProgram( + std::shared_ptr &originProgramDesc, + std::shared_ptr &scope) { + for (const auto &block : originProgramDesc.get()->Blocks()) { + for (const auto &var_desc : block->Vars()) { + auto var = scope.get()->Var(var_desc->Name()); + if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { + if (var_desc->Persistable() && + var_desc->Type() != framework::VARTYPE_TYPE_FEED_MINIBATCH && + var_desc->Type() != framework::VARTYPE_TYPE_FETCH_LIST) { + auto dim = var_desc->Tensor_desc().Dims(); + auto tensor = var->GetMutable(); + tensor->Resize(framework::make_ddim(dim)); + } else { + auto dim = var_desc->Tensor_desc().Dims(); + PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0"); + dim[0] = 1; + auto tensor = var->GetMutable(); + tensor->Resize(framework::make_ddim(dim)); + } + } else { + // TODO(codeWorm): some. + } + } + } +} +/** + * fusion and print someinfos + * @tparam Dtype + * @tparam P + * @param optimize + * @param can_add_split + * @param program + * @param originProgramDesc + */ +template +void FusionAndPrintInfos( + bool &optimize, bool &can_add_split, framework::Program &program, + const std::shared_ptr &originProgramDesc) { + if (optimize) { + framework::ProgramOptimize program_optimize; + program.optimizeProgram = + program_optimize.FusionOptimize(originProgramDesc, can_add_split); + } + if (optimize) { + program.optimizeProgram->Description("optimize: "); + } else { + originProgramDesc->Description("program: "); + } +} static size_t ReadBuffer(const char *file_name, uint8_t **out) { FILE *fp; fp = fopen(file_name, "rb"); @@ -87,46 +143,52 @@ const framework::Program Loader::LoadProgram( framework::Program program; program.originProgram = originProgramDesc; program.quantification = quantification; - + program.combined_params_len = 0; + program.combined_params_buf = nullptr; auto scope = std::make_shared(); program.scope = scope; - for (const auto &block : originProgramDesc->Blocks()) { - for (auto var_desc : block->Vars()) { - auto var = scope->Var(var_desc->Name()); + // use originProgramDesc and scope to init tensors + InitMemoryFromProgram(originProgramDesc, scope); + // perform fusion and print infos + FusionAndPrintInfos(optimize, can_add_split, program, originProgramDesc); - if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { - if (var_desc->Persistable() && - var_desc->Type() != framework::VARTYPE_TYPE_FEED_MINIBATCH && - var_desc->Type() != framework::VARTYPE_TYPE_FETCH_LIST) { - auto dim = var_desc->Tensor_desc().Dims(); - auto tensor = var->GetMutable(); - tensor->Resize(framework::make_ddim(dim)); - } else { - auto dim = var_desc->Tensor_desc().Dims(); - PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0"); - dim[0] = 1; - auto tensor = var->GetMutable(); - tensor->Resize(framework::make_ddim(dim)); - } - } else { - // TODO(codeWorm): some. - } - } - } + paddle_mobile__framework__proto__program_desc__free_unpacked(c_program, NULL); + return program; +} - if (optimize) { - framework::ProgramOptimize program_optimize; - program.optimizeProgram = - program_optimize.FusionOptimize(originProgramDesc, can_add_split); - } - if (optimize) { - program.optimizeProgram->Description("optimize: "); - } else { - originProgramDesc->Description("program: "); - } +template +const framework::Program Loader::LoadCombinedMemory( + size_t read_size, const uint8_t *buf, size_t combined_params_len, + const uint8_t *combined_params_buf, bool optimize, bool quantification) { + bool can_add_split = false; - paddle_mobile__framework__proto__program_desc__free_unpacked(c_program, NULL); + PaddleMobile__Framework__Proto__ProgramDesc *c_program; + PADDLE_MOBILE_ENFORCE(buf != nullptr, "read from __model__ is null"); + + c_program = paddle_mobile__framework__proto__program_desc__unpack( + nullptr, read_size, buf); + // + PADDLE_MOBILE_ENFORCE(c_program != nullptr, "program is null"); + // + DLOG << "n_ops: " << (*c_program->blocks)->n_ops; + // + + auto originProgramDesc = std::make_shared(c_program); + + framework::Program program; + program.combined = true; + program.originProgram = originProgramDesc; + program.quantification = quantification; + program.combined_params_len = combined_params_len; + program.combined_params_buf = combined_params_buf; + + auto scope = std::make_shared(); + program.scope = scope; + InitMemoryFromProgram(originProgramDesc, scope); + FusionAndPrintInfos(optimize, can_add_split, program, originProgramDesc); + paddle_mobile__framework__proto__program_desc__free_unpacked(c_program, + nullptr); return program; } diff --git a/src/io/loader.h b/src/io/loader.h index 512cee831f..505366793d 100644 --- a/src/io/loader.h +++ b/src/io/loader.h @@ -42,6 +42,11 @@ class Loader { bool optimize = false, bool quantification = false); + const framework::Program LoadCombinedMemory( + size_t model_len, const uint8_t *model_buf, size_t combined_params_len, + const uint8_t *combined_params_buf, bool optimize = false, + bool quantification = false); + private: const framework::Program LoadProgram(const std::string &model_path, bool optimize = false, diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 5e2e209d64..420a35c213 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -64,6 +64,32 @@ bool PaddleMobile::Load(const std::string &model_path, return true; } +template +bool PaddleMobile::LoadCombinedMemory( + size_t model_len, const uint8_t *model_buf, size_t combined_params_len, + const uint8_t *combined_params_buf) { + int batch_size = 1; + bool optimise = true; + bool quantification = false; + + if (loader_.get() == nullptr) { + loader_ = std::make_shared>(); + } else { + LOG(kLOG_INFO) << "loader inited"; + } + + if (executor_.get() == nullptr) { + executor_ = std::make_shared>( + loader_->LoadCombinedMemory(model_len, model_buf, combined_params_len, + combined_params_buf, optimise, + quantification), + batch_size, optimise); + } else { + LOG(kLOG_INFO) << "executor inited"; + } + + return true; +} template std::shared_ptr PaddleMobile::Predict( const framework::Tensor &t) { diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 5dc3ccb21d..2617407d0f 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -66,6 +66,19 @@ class PaddleMobile { std::vector Predict(const std::vector &input, const std::vector &dims); + /** + * 从内存加载model 以及 combinedparams的接口 + * + * @param model_len model 文件的内存大小 + * @param model_buf model文件的内存 + * @param combined_params_len params文件的内存大小 + * @param combined_params_buf params文件的内存 + * @return + */ + bool LoadCombinedMemory(size_t model_len, const uint8_t *model_buf, + size_t combined_params_len, + const uint8_t *combined_params_buf); + void Clear(); ~PaddleMobile(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8f92b6dab9..f2229a1bfc 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -117,6 +117,10 @@ else () ADD_EXECUTABLE(test-load framework/test_load.cpp) target_link_libraries(test-load paddle-mobile) + # gen test log + ADD_EXECUTABLE(test-loadmemory framework/test_load_memory.cpp) + target_link_libraries(test-loadmemory paddle-mobile) + ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp) target_link_libraries(test-inference-api paddle-mobile) diff --git a/test/framework/test_load_memory.cpp b/test/framework/test_load_memory.cpp new file mode 100644 index 0000000000..4be7aaa82f --- /dev/null +++ b/test/framework/test_load_memory.cpp @@ -0,0 +1,67 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "../test_helper.h" +#include "../test_include.h" +static size_t ReadBuffer(const char *file_name, uint8_t **out) { + FILE *fp; + fp = fopen(file_name, "rb"); + PADDLE_MOBILE_ENFORCE(fp != nullptr, " %s open failed !", file_name); + fseek(fp, 0, SEEK_END); + auto size = static_cast(ftell(fp)); + rewind(fp); + DLOG << "model size: " << size; + *out = reinterpret_cast(malloc(size)); + size_t cur_len = 0; + size_t nread; + while ((nread = fread(*out + cur_len, 1, size - cur_len, fp)) != 0) { + cur_len += nread; + } + fclose(fp); + return cur_len; +} + +static char *Get_binary_data(std::string filename) { + FILE *file = fopen(filename.c_str(), "rb"); + PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ", + filename.c_str()); + fseek(file, 0, SEEK_END); + int64_t size = ftell(file); + PADDLE_MOBILE_ENFORCE(size > 0, "size is too small"); + rewind(file); + auto *data = new char[size]; + size_t bytes_read = fread(data, 1, size, file); + PADDLE_MOBILE_ENFORCE(bytes_read == size, + "read binary file bytes do not match with fseek"); + fclose(file); + return data; +} + +int main() { + paddle_mobile::PaddleMobile paddle_mobile; + auto model_path = std::string(g_genet_combine) + "/model"; + auto params_path = std::string(g_genet_combine) + "/params"; + uint8_t *bufModel = nullptr; + size_t sizeBuf = ReadBuffer(model_path.c_str(), &bufModel); + uint8_t *bufParams = nullptr; + + DLOG << "sizeBuf: " << sizeBuf; + size_t sizeParams = ReadBuffer(params_path.c_str(), &bufParams); + DLOG << "sizeParams: " << sizeParams; + + paddle_mobile.LoadCombinedMemory(sizeBuf, bufModel, sizeParams, bufParams); + return 0; +} -- GitLab