diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index f6c4678370d6430a4a6f99711207be5a22b759a5..751d04fc6502d5376a7a5254afc83216d9dd2ee6 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -238,7 +238,7 @@ void Executor::InitCombineMemory() { template void Executor::InitNoPersistableMemory( - const LoDTensor &input_tensor) { + const Tensor &input_tensor) { for (const auto &block : program_desc_->Blocks()) { for (const auto &var_desc : block->Vars()) { auto var = program_.scope->Var(var_desc->Name()); @@ -336,9 +336,9 @@ void Executor::SetInput(const Tensor &input, auto *target_tensor = target_var->template GetMutable(); if (config_.load_when_predict) { - if (target_tensor->IsInitialized() && - target_tensor->dims() != input.dims()) { - InitNoPersistableMemory(*target_tensor); + if (input_dim_last_ != input.dims()) { + InitNoPersistableMemory(input); + input_dim_last_ = input.dims(); } } @@ -355,9 +355,9 @@ void Executor::SetInput(const LoDTensor &input, auto *target_tensor = target_var->template GetMutable(); if (config_.load_when_predict) { - if (target_tensor->IsInitialized() && - target_tensor->dims() != input.dims()) { + if (input_dim_last_ != input.dims()) { InitNoPersistableMemory(*target_tensor); + input_dim_last_ = input.dims(); } } diff --git a/src/framework/executor.h b/src/framework/executor.h index e77df5174ca60600ae7938b8bfd3bca5b2b9c9f3..edbfd5cdcc91b6f746f71a755311a1c80e24941c 100644 --- a/src/framework/executor.h +++ b/src/framework/executor.h @@ -65,7 +65,7 @@ class Executor { LoDTensor *tensor) const; void InitMemory(); void InitCombineMemory(); - void InitNoPersistableMemory(const LoDTensor &input_tensor); + void InitNoPersistableMemory(const Tensor &input_tensor); void LoadMemory(void **data, const std::shared_ptr var_desc, LoDTensor *tensor); #ifdef PADDLE_MOBILE_CL