diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index bf9ec6a87fb88839087bab0e21c406343e437b79..d511ad48ca3cf794ea473ab5bc0d791d4c72eded 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -299,31 +299,26 @@ void Executor::InitNoPersistableMemory(const Tensor &input_tensor) { for (const auto &block : program_desc_->Blocks()) { for (const auto &var_desc : block->Vars()) { auto var = program_.scope->Var(var_desc->Name()); - auto tensor = var->template GetMutable(); - if (var_desc->Persistable()) { - if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); - continue; - } - } else { - if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) { + if (!var_desc->Persistable() && + var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) { + DLOG << "InitNoPersistableMemory var " << var_desc->Name(); + auto tensor = var->template GetMutable(); + if (tensor->IsInitialized()) { + DLOG << "var's tensor is Initialized"; DDim tensor_dim = tensor->dims(); DDim new_dim = make_ddim({tensor_dim[0], tensor_dim[1], input_tensor.dims()[2], input_tensor.dims()[3]}); tensor->Resize(new_dim); - tensor->template mutable_data(); + tensor->template mutable_data_new(); + DLOG << "var's tensor dims " << tensor_dim; + DLOG << "var's tensor new dims " << new_dim; } else { - PADDLE_MOBILE_THROW_EXCEPTION("Unsupported var type `%d`", - var_desc->Type()); + DLOG << "var's tensor is not Initialized ???"; } } } } - - std::shared_ptr output = GetOutput("fetch"); - output->Resize(input_tensor.dims()); - output->mutable_data(); } template @@ -411,6 +406,9 @@ void Executor::SetInput(const Tensor &input, target.ShareDataWith(input); if (feed_indices_.size() == 1) { auto &dim = input.dims(); + if (lod_mode_ && product(dim) < 0.9 * product(input_dim_last_)) { + InitNoPersistableMemory(target); + } input_dim_has_changed_ = input_dim_last_ != dim; input_dim_last_ = static_cast(dim); } @@ -432,6 +430,9 @@ void Executor::SetInput(const LoDTensor &input, target.set_lod(input.lod()); if (feed_indices_.size() == 1) { auto &dim = input.dims(); + if (lod_mode_ && product(dim) < 0.9 * product(input_dim_last_)) { + InitNoPersistableMemory(target); + } input_dim_has_changed_ = input_dim_last_ != dim; input_dim_last_ = static_cast(dim); } diff --git a/src/framework/tensor.h b/src/framework/tensor.h index de70d557358234b98a984ec436370d42e5be69cc..93f11d84024033805baef1d1e9073585f1c75729 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -68,7 +68,8 @@ class Tensor : public TensorBase { Resize(ddim); auto type = type_id().hash_code(); int64_t size = numel() * SizeOfType(type); - holder_.reset(new PlaceholderImpl(size, type, (uint8_t *)input)); + holder_.reset( + new PlaceholderImpl(size, type, reinterpret_cast(input))); holder_->set_type(type); offset_ = 0; } @@ -103,6 +104,29 @@ class Tensor : public TensorBase { return *this; } + template + inline T *mutable_data_new() { + static_assert(std::is_pod::value, "T must be POD"); + const kTypeId_t type = type_id().hash_code(); + + if (holder_ != nullptr) { + holder_->set_type(type); + } + + PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.") + int64_t size = numel() * SizeOfType(type); + if (holder_ == nullptr || holder_->size() != size + offset_) { + if (holder_ == nullptr) { + holder_.reset(new PlaceholderImpl(size, type)); + } else { + holder_->realloc(size); + } + offset_ = 0; + } + return reinterpret_cast(reinterpret_cast(holder_->ptr()) + + offset_); + } + inline void *mutable_data(const kTypeId_t type) { if (holder_ != nullptr) { holder_->set_type(type); @@ -244,6 +268,12 @@ class Tensor : public TensorBase { size_ = size; } + virtual void realloc(size_t size) { + capatity_ = size; + ptr_.reset(static_cast(memory::Alloc(capatity_))); + size_ = size; + } + std::unique_ptr> ptr_; /*! the size of memory block. */ diff --git a/src/framework/tensor_base.h b/src/framework/tensor_base.h index 027f1165a08509431fd1281f7b05174a7c64b7cc..a7f4aa1b8acadb8cd15676b3584b431b00d383a3 100644 --- a/src/framework/tensor_base.h +++ b/src/framework/tensor_base.h @@ -117,6 +117,8 @@ class TensorBase { virtual void set_type(kTypeId_t type) = 0; virtual void resize(size_t size) = 0; + + virtual void realloc(size_t size) = 0; }; /**