diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 1afc90aed3efb879553803d331603e9808ac9a19..210360f4ca75c7ff116e2fb9bc0a553383486e23 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -29,6 +29,7 @@ limitations under the License. */ #include "framework/scope.h" #include "framework/tensor.h" #include "memory/t_malloc.h" +#include "pass/memory_optimize.h" #include "pass/model_obfuscate.h" #ifdef PADDLE_MOBILE_CL #include "framework/cl/cl_image.h" @@ -66,9 +67,8 @@ Executor::Executor(const Program &program, #if !defined(PADDLE_MOBILE_FPGA) && !defined(PADDLE_MOBILE_FPGA_KD) && \ !defined(PADDLE_MOBILE_CL) if (config_.memory_optimization_level != NoMemoryOptimization) { - memoryOpt_ = std::make_shared(); - (*memoryOpt_)(program_desc_.get(), program_.scope.get(), - config_.memory_optimization_level); + pass::MemoryOptPass()(program_desc_.get(), program_.scope.get(), + config_.memory_optimization_level); } #endif // resize feed and fetch list @@ -296,34 +296,32 @@ static void ClearNoPersistableTensorArray(const framework::ProgramDesc *program, template void Executor::InitNoPersistableMemory(const Tensor &input_tensor) { + if (input_tensor.dims().size() != 4) { + return; + } for (const auto &block : program_desc_->Blocks()) { for (const auto &var_desc : block->Vars()) { auto var = program_.scope->Var(var_desc->Name()); - auto tensor = var->template GetMutable(); - if (var_desc->Persistable()) { - if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); - continue; - } - } else { - if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) { + if (!var_desc->Persistable() && + var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) { + DLOG << "InitNoPersistableMemory var " << var_desc->Name(); + auto tensor = var->template GetMutable(); + if (tensor->IsInitialized() && tensor->dims().size() == 4) { + DLOG << "var's tensor is Initialized or dims size != 4"; DDim tensor_dim = tensor->dims(); DDim new_dim = make_ddim({tensor_dim[0], tensor_dim[1], input_tensor.dims()[2], input_tensor.dims()[3]}); tensor->Resize(new_dim); - tensor->template mutable_data(); + tensor->template mutable_data_new(); + DLOG << "var's tensor dims " << tensor_dim; + DLOG << "var's tensor new dims " << new_dim; } else { - PADDLE_MOBILE_THROW_EXCEPTION("Unsupported var type `%d`", - var_desc->Type()); + DLOG << "var's tensor is not Initialized ???"; } } } } - - std::shared_ptr output = GetOutput("fetch"); - output->Resize(input_tensor.dims()); - output->mutable_data(); } template @@ -411,7 +409,9 @@ void Executor::SetInput(const Tensor &input, target.ShareDataWith(input); if (feed_indices_.size() == 1) { auto &dim = input.dims(); - shouldAdjustMemory_ = (product(dim) < 0.9 * product(input_dim_last_)); + if (lod_mode_ && product(dim) < 0.9 * product(input_dim_last_)) { + InitNoPersistableMemory(target); + } input_dim_has_changed_ = input_dim_last_ != dim; input_dim_last_ = static_cast(dim); } @@ -433,7 +433,9 @@ void Executor::SetInput(const LoDTensor &input, target.set_lod(input.lod()); if (feed_indices_.size() == 1) { auto &dim = input.dims(); - shouldAdjustMemory_ = (product(dim) < 0.9 * product(input_dim_last_)); + if (lod_mode_ && product(dim) < 0.9 * product(input_dim_last_)) { + InitNoPersistableMemory(target); + } input_dim_has_changed_ = input_dim_last_ != dim; input_dim_last_ = static_cast(dim); } @@ -483,16 +485,7 @@ PMStatus Executor::Predict() { // clear all no persistable tensor array since write_to_array // is always push back a new tensor in the array ClearNoPersistableTensorArray(program_desc_.get(), program_.scope.get()); - if (lod_mode_ && input_dim_has_changed_) { - for (int i = 0; i < ops_of_block0_.size(); ++i) { - auto &op_handler = ops_of_block0_[i]; - op_handler->InferShape(); - } - if (memoryOpt_ != nullptr && shouldAdjustMemory_) { - shouldAdjustMemory_ = false; - memoryOpt_->AdjustMemory(); - } - } + #ifdef PADDLE_MOBILE_PROFILE std::vector profile(ops_of_block0_.size()); struct timespec ts; @@ -503,12 +496,12 @@ PMStatus Executor::Predict() { #ifdef PADDLE_MOBILE_PROFILE clock_gettime(CLOCK_MONOTONIC, &ts); profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; -// if (lod_mode_ && input_dim_has_changed_) { -// op_handler->InferShape(); -// } #endif DLOG << i << "th, " << "run op: " << op_handler->Type(); + if (lod_mode_ && input_dim_has_changed_) { + op_handler->InferShape(); + } op_handler->Run(); #ifdef PADDLE_MOBILE_PROFILE clock_gettime(CLOCK_MONOTONIC, &ts); diff --git a/src/framework/executor.h b/src/framework/executor.h index d898c81264b21e276256c7c12a814df595b8f021..81b37734d673304c2a303ca9024aea9fb5c543d5 100644 --- a/src/framework/executor.h +++ b/src/framework/executor.h @@ -27,7 +27,6 @@ limitations under the License. */ #include "framework/program/program.h" #include "framework/tensor.h" #include "framework/type_trait.h" -#include "pass/memory_optimize.h" namespace paddle_mobile { namespace framework { @@ -105,9 +104,6 @@ class Executor { DDim input_dim_last_; bool input_dim_has_changed_ = true; - bool shouldAdjustMemory_ = false; - std::shared_ptr memoryOpt_; - #ifdef PADDLE_MOBILE_PROFILE typedef typename DtypeTensorTrait::gtype ProfileTensorType; diff --git a/src/framework/tensor.h b/src/framework/tensor.h index 6c96400bf130366f9759c3da8d0e2cc5f2f4b655..93f11d84024033805baef1d1e9073585f1c75729 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -104,14 +104,27 @@ class Tensor : public TensorBase { return *this; } - inline void mutable_data_new() { + template + inline T *mutable_data_new() { + static_assert(std::is_pod::value, "T must be POD"); + const kTypeId_t type = type_id().hash_code(); + if (holder_ != nullptr) { - PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.") - int64_t size = numel() * SizeOfType(holder_->type()); - if (holder_->size() != size + offset_) { - holder_->realloc(size + offset_); + holder_->set_type(type); + } + + PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.") + int64_t size = numel() * SizeOfType(type); + if (holder_ == nullptr || holder_->size() != size + offset_) { + if (holder_ == nullptr) { + holder_.reset(new PlaceholderImpl(size, type)); + } else { + holder_->realloc(size); } + offset_ = 0; } + return reinterpret_cast(reinterpret_cast(holder_->ptr()) + + offset_); } inline void *mutable_data(const kTypeId_t type) { diff --git a/src/pass/memory_optimize.cpp b/src/pass/memory_optimize.cpp index 4c8f3a3af5a905cafb14607c92e81f91300882a6..d9cfa1389955ed503f0aae12e3251e01d2fe9a13 100644 --- a/src/pass/memory_optimize.cpp +++ b/src/pass/memory_optimize.cpp @@ -57,7 +57,6 @@ void MemoryOptPass::operator()( AppendBlockVars(block.get()); reused_nodes_.clear(); - memoryDeputies_.clear(); // collect all not persistable variables, and accumulate // it's reference count std::stack empty_var_nodes; @@ -157,33 +156,15 @@ void MemoryOptPass::operator()( auto *reuse_tensor = reused_var->template GetMutable(); reuse_tensor->mutable_data(); - framework::Variable *deputyVar; - int64_t varSize = 0; for (const auto &node : list) { DLOG << node->name; auto *var = scope->Var(node->name); auto *tensor = var->template GetMutable(); tensor->ShareHolderWith(*reuse_tensor); - if (tensor->numel() > varSize) { - varSize = tensor->numel(); - deputyVar = var; - } - } - if (deputyVar) { - memoryDeputies_.push_back(deputyVar); } } } } -void MemoryOptPass::AdjustMemory() { - for (auto &deputy : memoryDeputies_) { - if (deputy->IsType()) { - auto *tensor = deputy->template GetMutable(); - tensor->mutable_data_new(); - } - } -} - } // namespace pass } // namespace paddle_mobile diff --git a/src/pass/memory_optimize.h b/src/pass/memory_optimize.h index a9f02e23f8f2a37434665d6152107735949be7f3..f0171c5ba6951ace2efac2fc5840b8878df3d1de 100644 --- a/src/pass/memory_optimize.h +++ b/src/pass/memory_optimize.h @@ -51,14 +51,11 @@ class MemoryOptPass : public PassBase { VarNode *CreateNode(const std::string name); - void AdjustMemory(); - private: std::stack analysis_nodes_; std::vector> reused_nodes_; std::unordered_map created_nodes_; std::unordered_map block_vars_; - std::vector memoryDeputies_; }; } // namespace pass