提交 67bd7f3f 编写于 作者: Z zhaojiaying01

run superresoltion on cpu

上级 9e6ce432
...@@ -238,7 +238,7 @@ void Executor<Device, T>::InitCombineMemory() { ...@@ -238,7 +238,7 @@ void Executor<Device, T>::InitCombineMemory() {
template <typename Device, typename T> template <typename Device, typename T>
void Executor<Device, T>::InitNoPersistableMemory( void Executor<Device, T>::InitNoPersistableMemory(
const LoDTensor &input_tensor) { const Tensor &input_tensor) {
for (const auto &block : program_desc_->Blocks()) { for (const auto &block : program_desc_->Blocks()) {
for (const auto &var_desc : block->Vars()) { for (const auto &var_desc : block->Vars()) {
auto var = program_.scope->Var(var_desc->Name()); auto var = program_.scope->Var(var_desc->Name());
...@@ -336,9 +336,9 @@ void Executor<Device, T>::SetInput(const Tensor &input, ...@@ -336,9 +336,9 @@ void Executor<Device, T>::SetInput(const Tensor &input,
auto *target_tensor = target_var->template GetMutable<LoDTensor>(); auto *target_tensor = target_var->template GetMutable<LoDTensor>();
if (config_.load_when_predict) { if (config_.load_when_predict) {
if (target_tensor->IsInitialized() && if (input_dim_last_ != input.dims()) {
target_tensor->dims() != input.dims()) { InitNoPersistableMemory(input);
InitNoPersistableMemory(*target_tensor); input_dim_last_ = input.dims();
} }
} }
...@@ -355,9 +355,9 @@ void Executor<Device, T>::SetInput(const LoDTensor &input, ...@@ -355,9 +355,9 @@ void Executor<Device, T>::SetInput(const LoDTensor &input,
auto *target_tensor = target_var->template GetMutable<LoDTensor>(); auto *target_tensor = target_var->template GetMutable<LoDTensor>();
if (config_.load_when_predict) { if (config_.load_when_predict) {
if (target_tensor->IsInitialized() && if (input_dim_last_ != input.dims()) {
target_tensor->dims() != input.dims()) {
InitNoPersistableMemory(*target_tensor); InitNoPersistableMemory(*target_tensor);
input_dim_last_ = input.dims();
} }
} }
......
...@@ -65,7 +65,7 @@ class Executor { ...@@ -65,7 +65,7 @@ class Executor {
LoDTensor *tensor) const; LoDTensor *tensor) const;
void InitMemory(); void InitMemory();
void InitCombineMemory(); void InitCombineMemory();
void InitNoPersistableMemory(const LoDTensor &input_tensor); void InitNoPersistableMemory(const Tensor &input_tensor);
void LoadMemory(void **data, const std::shared_ptr<VarDesc> var_desc, void LoadMemory(void **data, const std::shared_ptr<VarDesc> var_desc,
LoDTensor *tensor); LoDTensor *tensor);
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册