提交 d7623790 编写于 作者: H hjchen2

Fix feed op crash for opencl

上级 aae862c1
...@@ -68,7 +68,7 @@ struct CPUContext { ...@@ -68,7 +68,7 @@ struct CPUContext {
}; };
inline void set_global_num_threads(int threads) { inline void set_global_num_threads(int threads) {
CPUContext::Context()->num_threads = threads; CPUContext::Context()->set_num_threads(threads);
} }
inline int get_global_num_threads() { inline int get_global_num_threads() {
......
...@@ -663,14 +663,18 @@ void Executor<GPU_CL, float>::InitNoPersistableMemory( ...@@ -663,14 +663,18 @@ void Executor<GPU_CL, float>::InitNoPersistableMemory(
output->Resize(input_tensor.dims()); output->Resize(input_tensor.dims());
output->mutable_data<float>(); output->mutable_data<float>();
} }
template <> template <>
void Executor<GPU_CL, float>::SetInput(const Tensor &input, void Executor<GPU_CL, float>::SetInput(const Tensor &input,
const std::string &var_name) { const std::string &var_name) {
auto *target_var = program_.scope->FindVar(var_name); int index = 0;
PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist", if (feed_indices_.find(var_name) != feed_indices_.end()) {
var_name.c_str()); index = feed_indices_.find(var_name)->second;
}
auto *feed_var = program_.scope->Var("feed");
framework::LoDTensor *target_tensor =
&(feed_var->template GetMutable<framework::LoDTensorArray>()->at(index));
auto *target_tensor = target_var->template GetMutable<LoDTensor>();
DLOG << "config_.load_when_predict " << config_.load_when_predict; DLOG << "config_.load_when_predict " << config_.load_when_predict;
DLOG << "target_tensor->IsInitialized() " << target_tensor->IsInitialized(); DLOG << "target_tensor->IsInitialized() " << target_tensor->IsInitialized();
DLOG << "target_tensor->dims() " << target_tensor->dims(); DLOG << "target_tensor->dims() " << target_tensor->dims();
...@@ -781,7 +785,7 @@ void Executor<GPU_CL, float>::InitMemory() { ...@@ -781,7 +785,7 @@ void Executor<GPU_CL, float>::InitMemory() {
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
CLImage *cl_image = nullptr; CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<LoDTensor>(); var->template GetMutable<framework::LoDTensorArray>();
continue; continue;
} else { } else {
cl_image = var->template GetMutable<CLImage>(); cl_image = var->template GetMutable<CLImage>();
...@@ -849,7 +853,7 @@ void Executor<GPU_CL, float>::InitCombineMemory() { ...@@ -849,7 +853,7 @@ void Executor<GPU_CL, float>::InitCombineMemory() {
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
CLImage *cl_image = nullptr; CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<LoDTensor>(); var->template GetMutable<framework::LoDTensorArray>();
continue; continue;
} else { } else {
cl_image = var->template GetMutable<CLImage>(); cl_image = var->template GetMutable<CLImage>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册