diff --git a/src/framework/program/program-optimize/program_optimize.cpp b/src/framework/program/program-optimize/program_optimize.cpp index 15724523ded18e14cecf5d5aacf506992dadb3b4..2e802b120a6effd33bbe68048123c33f76a8aee8 100644 --- a/src/framework/program/program-optimize/program_optimize.cpp +++ b/src/framework/program/program-optimize/program_optimize.cpp @@ -106,11 +106,14 @@ std::shared_ptr ProgramOptimize::FushionOptimize( } std::vector> op_descs; - for (int m = 0; m < nodes.size(); ++m) { - auto &node = nodes[m]; - op_descs.push_back(node->op_desc_); + if (add_split) { + GenerateOps(&op_descs, begin_node.get(), add_split); + } else { + for (int m = 0; m < nodes.size(); ++m) { + auto &node = nodes[m]; + op_descs.push_back(node->op_desc_); + } } - // GenerateOps(&op_descs, begin_node.get()); block->ops_ = op_descs; } @@ -267,12 +270,12 @@ void ProgramOptimize::GenerateOps( } void ProgramOptimize::GenerateOps( - std::vector> *op_descs, - Node *begin_node) { + std::vector> *op_descs, Node *begin_node, + bool can_add_split) { // std::vector> *op_desc, // Node *input_node, Node *current_node, bool adding_thread, int // thread_num - if (false) { + if (can_add_split) { this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr); } else { this->GenerateOps(op_descs, begin_node, begin_node); diff --git a/src/framework/program/program-optimize/program_optimize.h b/src/framework/program/program-optimize/program_optimize.h index 93943cf83951565d91f67bfa77881dbcb130278d..ae632da4bdf004f23e7dab86ab06a4e007fdb75b 100644 --- a/src/framework/program/program-optimize/program_optimize.h +++ b/src/framework/program/program-optimize/program_optimize.h @@ -34,7 +34,7 @@ class ProgramOptimize { int current_block_; std::vector> new_blocks_; void GenerateOps(std::vector> *op_descs, - Node *begin_node); + Node *begin_node, bool can_add_split); void GenerateOps(std::vector> *op_desc, Node *input_node, Node *current_node); void GenerateOps(std::vector> *op_desc, diff --git a/src/io/io.cpp b/src/io/io.cpp index 019770399e29ea8bdd896b2348a23c09a5d27a95..c60113bb7882b7482cf5a23e4ad48adb6ec63de8 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -76,8 +76,9 @@ static size_t ReadBuffer(const char *file_name, uint8_t **out) { template const framework::Program Loader::Load( - const std::string &dirname, bool optimize) { - auto program = this->LoadProgram(dirname + "/__model__", optimize); + const std::string &dirname, bool optimize, bool can_add_split) { + auto program = + this->LoadProgram(dirname + "/__model__", optimize, can_add_split); program.model_path = dirname; return program; } @@ -94,7 +95,7 @@ const framework::Program Loader::Load( template const framework::Program Loader::LoadProgram( - const std::string &model_path, bool optimize) { + const std::string &model_path, bool optimize, bool can_add_split) { std::string model_filename = model_path; PaddleMobile__Framework__Proto__ProgramDesc *c_program; uint8_t *buf = NULL; @@ -146,7 +147,7 @@ const framework::Program Loader::LoadProgram( if (optimize) { framework::ProgramOptimize program_optimize; program.optimizeProgram = - program_optimize.FushionOptimize(originProgramDesc); + program_optimize.FushionOptimize(originProgramDesc, can_add_split); } if (optimize) { program.optimizeProgram->Description("optimize: "); @@ -310,6 +311,7 @@ void Executor::InitMemory() { template void Executor::InitCombineMemory() { + LOG(kLOG_INFO) << " begin init combine memory"; char *origin_data = Get_binary_data(program_.para_path); char *data = origin_data; for (const auto &block : to_predict_program_->Blocks()) { @@ -330,6 +332,7 @@ void Executor::InitCombineMemory() { } } delete origin_data; + LOG(kLOG_INFO) << " end init combine memory "; } template diff --git a/src/io/io.h b/src/io/io.h index fb18ca0cc1768f5cfe39acfcba7d0117a67e1de5..a1fbf158c2b026336d363db512cb44fe58ee93db 100644 --- a/src/io/io.h +++ b/src/io/io.h @@ -35,7 +35,8 @@ class Loader { * @b 加载分开形式的 fluid 模型 * */ const framework::Program Load(const std::string &dirname, - bool optimize = false); + bool optimize = false, + bool can_add_split = false); /* * @b load combine format fluid mode @@ -47,7 +48,8 @@ class Loader { private: const framework::Program LoadProgram(const std::string &model_path, - bool optimize = false); + bool optimize = false, + bool can_add_split = false); }; template diff --git a/test/framework/test_load.cpp b/test/framework/test_load.cpp index 2300f05c99a122b352d888a45ca3c6ef082469ba..32d314826f8d6bd4e504b16cd78464d660919a30 100644 --- a/test/framework/test_load.cpp +++ b/test/framework/test_load.cpp @@ -19,9 +19,10 @@ int main() { paddle_mobile::Loader loader; // ../../../test/models/googlenet // ../../../test/models/mobilenet - auto program = loader.Load(g_resnet, true); - loader.Load(g_googlenet_combine + "/model", g_googlenet_combine + "/params", - true); + auto program = loader.Load(g_googlenet, true, true); + // loader.Load(g_googlenet_combine + "/model", g_googlenet_combine + + // "/params", + // true); program.originProgram->Description("program desc: "); return 0;