diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 26c09eb8eb9db6f30e5d2e3ce332a8a62975fa40..9ed974151fb4b29b5396571cde7747c01709c24f 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -33,11 +33,6 @@ void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) { } void FetchOpHandle::WaitAndMergeCPUTensors() const { - // Wait fetch stream done. - for (auto &ctx : dev_ctx_) { - ctx.second->Wait(); - } - std::vector tensors_ptr; tensors_ptr.reserve(tensors_.size()); for (auto &t : tensors_) { @@ -72,6 +67,8 @@ void FetchOpHandle::RunImpl() { tensors_[i].ShareDataWith(t); tensors_[i].set_lod(t.lod()); } + + this->WaitAndMergeCPUTensors(); } } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 7d1f7e46b8435ec0ef1913ea70d9a8f7a6734aac..7cfd66837966836774113229da3900c333f1c962 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -96,12 +96,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (auto *var : vars) { op->AddInput(var); } - - dummy_vars.emplace_back(); - auto *var = &dummy_vars.back(); - var->generated_op_ = nullptr; - op->AddOutput(var); - InsertPendingVar(*var); InsertPendingOp(*op); } @@ -176,8 +170,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( }; // Wait FetchOps. - for (auto &fetch_op : fetch_ops) { - fetch_op.WaitAndMergeCPUTensors(); + if (!fetch_ops.empty()) { sync_computation(); }