From 19ddd6ee25b164df59ec33bd1348eb9c145f020e Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Fri, 15 Mar 2019 19:14:33 +0800 Subject: [PATCH] Fix feed op crash for opencl --- src/framework/context.h | 2 +- src/framework/executor.cpp | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/framework/context.h b/src/framework/context.h index 535bb759fe..d38e1e3b56 100644 --- a/src/framework/context.h +++ b/src/framework/context.h @@ -68,7 +68,7 @@ struct CPUContext { }; inline void set_global_num_threads(int threads) { - CPUContext::Context()->num_threads = threads; + CPUContext::Context()->set_num_threads(threads); } inline int get_global_num_threads() { diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 8eee94ce22..a15c0e6b4e 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -663,14 +663,18 @@ void Executor::InitNoPersistableMemory( output->Resize(input_tensor.dims()); output->mutable_data(); } + template <> void Executor::SetInput(const Tensor &input, const std::string &var_name) { - auto *target_var = program_.scope->FindVar(var_name); - PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist", - var_name.c_str()); + int index = 0; + if (feed_indices_.find(var_name) != feed_indices_.end()) { + index = feed_indices_.find(var_name)->second; + } + auto *feed_var = program_.scope->Var("feed"); + framework::LoDTensor *target_tensor = + &(feed_var->template GetMutable()->at(index)); - auto *target_tensor = target_var->template GetMutable(); DLOG << "config_.load_when_predict " << config_.load_when_predict; DLOG << "target_tensor->IsInitialized() " << target_tensor->IsInitialized(); DLOG << "target_tensor->dims() " << target_tensor->dims(); @@ -781,7 +785,7 @@ void Executor::InitMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + var->template GetMutable(); continue; } else { cl_image = var->template GetMutable(); @@ -849,7 +853,7 @@ void Executor::InitCombineMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + var->template GetMutable(); continue; } else { cl_image = var->template GetMutable(); -- GitLab