未验证 提交 28013ef9 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #10019 from reyoung/feature/feed_fetch_tensor_on_cpu

Fix FetchTensor on CPU
...@@ -51,23 +51,23 @@ void FetchOpHandle::RunImpl() { ...@@ -51,23 +51,23 @@ void FetchOpHandle::RunImpl() {
auto *var = static_cast<VarHandle *>(input); auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(cpu_ctx); var->generated_op_->Wait(cpu_ctx);
} }
tensors_.resize(inputs_.size()); tensors_.resize(inputs_.size());
auto *var = static_cast<VarHandle *>(inputs_[0]); auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var->name_; auto &var_name = var_handle->name_;
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto &scopes = *local_scopes_; auto &scopes = *local_scopes_;
for (size_t i = 0; i < scopes.size(); ++i) { for (size_t i = 0; i < scopes.size(); ++i) {
auto &scope = scopes[i]; auto &scope = scopes[i];
auto &t = scope->FindVar(kLocalExecScopeName) auto *var =
->Get<Scope *>() scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name);
->FindVar(var_name) PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
->Get<framework::LoDTensor>(); var_name);
if (platform::is_gpu_place(var->place_)) { auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]); TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);
dev_ctxes_[t.place()]->Wait(); dev_ctxes_.at(t.place())->Wait();
#endif #endif
} else { } else {
tensors_[i].ShareDataWith(t); tensors_[i].ShareDataWith(t);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册