diff --git a/src/common/types.h b/src/common/types.h
index c12e5b6a268f66f7fcf53d55d1f40a15093474e3..25bb9b994bc9e865eac210915fb9a2a4974163cf 100644
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -107,6 +107,10 @@ enum PoolingType {
   AVG = 1,
 };
 
+struct PaddleMobileConfigInternal {
+  bool load_when_predict = false;
+};
+
 extern const char *G_OP_TYPE_CONV;
 extern const char *G_OP_TYPE_BATCHNORM;
 extern const char *G_OP_TYPE_BOX_CODER;
diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp
index 884f5200f2bd9ec1b86429b4d37c3e58ea16724e..f4aafb7ca86951402764cd5d723a5408eaf684f4 100644
--- a/src/framework/executor.cpp
+++ b/src/framework/executor.cpp
@@ -37,6 +37,12 @@ namespace framework {
 
 #pragma mark - executor
 
+template <typename Device, typename T>
+Executor<Device, T>::Executor(const Program<Device> &program, paddle_mobile::PaddleMobileConfigInternal config, int batch_size,
+                              const bool use_optimize, const bool lod_mode): Executor(program, batch_size, use_optimize, lod_mode) {
+  config_ = config;
+};
+
 template <typename Device, typename T>
 Executor<Device, T>::Executor(const Program<Device> &program, int batch_size,
                               const bool use_optimize, const bool lod_mode)
@@ -212,10 +218,16 @@ void Executor<Device, T>::InitCombineMemory() {
       if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
         continue;
       }
+
+      DLOG << " init combine memory persistable: " << var_desc->Name();
+
       LoadMemory(reinterpret_cast<void **>(&data), var_desc, tensor);
     } else {
       if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
+        DLOG << " init combine memory no persistable in lod: " << var_desc->Name();
         varInputMemory(var_desc, var, tensor);
+      } else {
+        DLOG << " init combine memory no persistable: " << var_desc->Name();
       }
     }
   }
@@ -226,6 +238,32 @@
   LOG(kLOG_INFO) << "init combine memory finish";
 }
 
+template <typename Device, typename T>
+void Executor<Device, T>::InitNoPersistableMemory(const LoDTensor &input_tensor) {
+  for (const auto &block : program_desc_->Blocks()) {
+    for (const auto &var_desc : block->Vars()) {
+      auto var = program_.scope->Var(var_desc->Name());
+      auto tensor = var->template GetMutable<LoDTensor>();
+      if (var_desc->Persistable()) {
+        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
+          continue;
+        }
+      } else {
+        if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
+          DDim tensor_dim = tensor->dims();
+          DDim new_dim = make_ddim({tensor_dim[0], tensor_dim[1], input_tensor.dims()[2], input_tensor.dims()[3]});
+          tensor->template Resize(new_dim);
+          tensor->template mutable_data<T>();
+        }
+      }
+    }
+  }
+
+  std::shared_ptr<LoDTensor> output = GetOutput("fetch");
+  output->Resize(input_tensor.dims());
+  output->mutable_data<T>();
+}
+
 template <typename Device, typename T>
 bool Executor<Device, T>::varInputMemory(
     const std::shared_ptr<VarDesc> &var_desc, Variable *var,
@@ -275,6 +313,7 @@ PMStatus Executor<Device, T>::Predict(
 template <typename Device, typename T>
 std::vector<T> Executor<Device, T>::Predict(const std::vector<T> &input,
                                             const std::vector<int64_t> &dims) {
+
   Tensor feed_tensor(input, make_ddim(dims));
   SetInput(feed_tensor, "feed");
   std::vector<T> output;
@@ -293,7 +332,15 @@ void Executor<Device, T>::SetInput(const Tensor &input,
   auto *target_var = program_.scope->FindVar(var_name);
   PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
                         var_name.c_str());
+
   auto *target_tensor = target_var->template GetMutable<LoDTensor>();
+
+  if (config_.load_when_predict) {
+    if (target_tensor->IsInitialized() && target_tensor->dims() != input.dims()) {
+      InitNoPersistableMemory(*target_tensor);
+    }
+  }
+
   target_tensor->Resize(input.dims());
   target_tensor->ShareDataWith(input);
 }
@@ -301,10 +348,18 @@
 template <typename Device, typename T>
 void Executor<Device, T>::SetInput(const LoDTensor &input,
                                    const std::string &var_name) {
+
   auto *target_var = program_.scope->FindVar(var_name);
   PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
                         var_name.c_str());
   auto *target_tensor = target_var->template GetMutable<LoDTensor>();
+
+  if (config_.load_when_predict) {
+    if (target_tensor->IsInitialized() && target_tensor->dims() != input.dims()) {
+      InitNoPersistableMemory(*target_tensor);
+    }
+  }
+
   target_tensor->Resize(input.dims());
   target_tensor->ShareDataWith(input);
   target_tensor->set_lod(input.lod());
diff --git a/src/framework/executor.h b/src/framework/executor.h
index 0301e9e7980484d0f6280e69d0c5370adba13745..d54e9da8a2e1b6ee998c6181a31f9fdaf5048464 100644
--- a/src/framework/executor.h
+++ b/src/framework/executor.h
@@ -32,6 +32,8 @@ namespace framework {
 
 template <typename Device, typename T = float>
 class Executor {
  public:
+  Executor(const Program<Device> &program, paddle_mobile::PaddleMobileConfigInternal config, int batch_size = 1,
+           const bool use_optimize = true, const bool lod_mode = false);
   Executor(const Program<Device> &program, int batch_size = 1,
            const bool use_optimize = true, const bool lod_mode = false);
@@ -60,10 +62,13 @@
  protected:
   Executor() = default;
+
+
   bool varInputMemory(const std::shared_ptr<VarDesc> &var_desc, Variable *var,
                       LoDTensor *tensor) const;
 
   void InitMemory();
   void InitCombineMemory();
+  void InitNoPersistableMemory(const LoDTensor &input_tensor);
   void LoadMemory(void **data, const std::shared_ptr<VarDesc> var_desc,
                   LoDTensor *tensor);
 #ifdef PADDLE_MOBILE_CL
@@ -73,14 +78,18 @@ class Executor {
   int batch_size_;
   bool use_optimize_;
   bool lod_mode_;
+  PaddleMobileConfigInternal config_ = PaddleMobileConfigInternal();
   Program<Device> program_;
   std::shared_ptr<ProgramDesc> program_desc_;
-
   typedef std::shared_ptr<OperatorBase<Device>> OperatorBasePtr;
   std::vector<std::vector<OperatorBasePtr>> ops_of_block_;  // operators list
   std::vector<OperatorBasePtr> ops_list_;
+  // for super resoltion
+  DDim input_dim_;
+
+
 #ifdef PADDLE_MOBILE_PROFILE
   struct ProfInfo {
     int tid = 0;
diff --git a/src/framework/loader.h b/src/framework/loader.h
index bd4dfa15565dbb8e9afce769b12fe23eb7a1a970..e2a10856c1b5d68a3dd34c43e86f60dc1f102cb3 100644
--- a/src/framework/loader.h
+++ b/src/framework/loader.h
@@ -25,6 +25,7 @@ namespace framework {
 template <typename Device = CPU, typename T = float>
 class Loader {
  public:
+
   /*
    * @b load separate format fluid model
    * @b 加载分开存储的fluid模型
@@ -59,6 +60,7 @@ class Loader {
   void InitMemoryFromProgram(
       const std::shared_ptr<ProgramDesc> &originProgramDesc,
       const std::shared_ptr<Scope> &scope);
+
 };
 
 }  // namespace framework
diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp
index addaefad1466e7157a553bedbc869377723a9213..c1a8cc10c727a77ae5480a6f686cafc596dfd7cf 100644
--- a/src/io/paddle_mobile.cpp
+++ b/src/io/paddle_mobile.cpp
@@ -42,7 +42,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
   if (executor_.get() == nullptr) {
     executor_ = std::make_shared<framework::Executor<Device, T>>(
-        loader_->Load(dirname, optimize, quantification), batch_size, optimize,
+        loader_->Load(dirname, optimize, quantification), config_, batch_size, optimize,
         loddable);
   } else {
     LOG(kLOG_INFO) << "executor inited";
   }
@@ -64,8 +64,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
   if (executor_.get() == nullptr) {
     executor_ = std::make_shared<framework::Executor<Device, T>>(
-        loader_->Load(model_path, para_path, optimize, quantification),
-        batch_size, optimize, loddable);
+        loader_->Load(model_path, para_path, optimize, quantification), config_, batch_size, optimize, loddable);
   } else {
     LOG(kLOG_INFO) << "executor inited";
   }
 
@@ -87,7 +86,7 @@ bool PaddleMobile<Device, T>::LoadCombinedMemory(
     executor_ = std::make_shared<framework::Executor<Device, T>>(
         loader_->LoadCombinedMemory(model_len, model_buf, combined_params_len,
                                     combined_params_buf, optimize,
-                                    quantification),
+                                    quantification), config_,
         batch_size, optimize, loddable);
   } else {
     LOG(kLOG_INFO) << "executor inited";
diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h
index b98da215eb4dac5af4e424461f6a233ccf33a612..045243f82e5597a50dd4ab5a5d018e316034520f 100644
--- a/src/io/paddle_mobile.h
+++ b/src/io/paddle_mobile.h
@@ -33,9 +33,18 @@ limitations under the License. */
 
 namespace paddle_mobile {
 
+
 template <typename Device, typename T = float>
 class PaddleMobile {
  public:
+
+  PaddleMobile(PaddleMobileConfigInternal config): config_(config){
+#ifndef PADDLE_MOBILE_CL
+    bool is_gpu = std::is_same<DeviceType<kGPU_CL>, Device>::value;
+    PADDLE_MOBILE_ENFORCE(!is_gpu, "Please recompile with GPU_CL is on");
+#endif
+  }
+
   PaddleMobile() {
 #ifndef PADDLE_MOBILE_CL
     bool is_gpu = std::is_same<DeviceType<kGPU_CL>, Device>::value;
@@ -100,6 +109,7 @@ class PaddleMobile {
  private:
   std::shared_ptr<framework::Loader<Device, T>> loader_;
   std::shared_ptr<framework::Executor<Device, T>> executor_;
+  PaddleMobileConfigInternal config_;
 };
 
 }  // namespace paddle_mobile
diff --git a/test/net/test_super.cpp b/test/net/test_super.cpp
index dd77de3fbd39d5e9d65fc7faa66b55c85833eaf1..a47ed1ba162626ccbb8cd06af940f17bfd1c1607 100644
--- a/test/net/test_super.cpp
+++ b/test/net/test_super.cpp
@@ -18,7 +18,10 @@ limitations under the License. */
 #include "../test_include.h"
 
 int main() {
-  paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile;
+  paddle_mobile::PaddleMobileConfigInternal config;
+  config.load_when_predict = true;
+
+  paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile(config);
   //  paddle_mobile.SetThreadNum(4);
   auto time1 = paddle_mobile::time();
 #ifdef PADDLE_MOBILE_CL
@@ -27,7 +30,7 @@ int main() {
 
   auto isok =
       paddle_mobile.Load(std::string(g_super) + "/model",
                          std::string(g_super) + "/params", true, false,
-                         1, true);
+                         1, false);
   //  auto isok = paddle_mobile.Load(std::string(g_mobilenet_mul), true);
   if (isok) {