diff --git a/src/framework/context.h b/src/framework/context.h index 535bb759fec7534268f361eb9e2b0c6946024022..d38e1e3b5625b9151cc0c8c4ec41ce66080dd545 100644 --- a/src/framework/context.h +++ b/src/framework/context.h @@ -68,7 +68,7 @@ struct CPUContext { }; inline void set_global_num_threads(int threads) { - CPUContext::Context()->num_threads = threads; + CPUContext::Context()->set_num_threads(threads); } inline int get_global_num_threads() { diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 8eee94ce228c195bda5657eecb3570348001cd47..a15c0e6b4e73e6132c5118379dc7ffb5ec75f0a3 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -663,14 +663,18 @@ void Executor::InitNoPersistableMemory( output->Resize(input_tensor.dims()); output->mutable_data(); } + template <> void Executor::SetInput(const Tensor &input, const std::string &var_name) { - auto *target_var = program_.scope->FindVar(var_name); - PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist", - var_name.c_str()); + int index = 0; + if (feed_indices_.find(var_name) != feed_indices_.end()) { + index = feed_indices_.find(var_name)->second; + } + auto *feed_var = program_.scope->Var("feed"); + framework::LoDTensor *target_tensor = + &(feed_var->template GetMutable()->at(index)); - auto *target_tensor = target_var->template GetMutable(); DLOG << "config_.load_when_predict " << config_.load_when_predict; DLOG << "target_tensor->IsInitialized() " << target_tensor->IsInitialized(); DLOG << "target_tensor->dims() " << target_tensor->dims(); @@ -781,7 +785,7 @@ void Executor::InitMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + var->template GetMutable(); continue; } else { cl_image = var->template GetMutable(); @@ -849,7 +853,7 @@ void Executor::InitCombineMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + var->template GetMutable(); continue; } else { cl_image = var->template GetMutable();