From c18c2f6ab01082e14e76fdbcf384f577239bcc0f Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 20 Mar 2018 12:15:06 +0800 Subject: [PATCH] Sync all computation streams at the end of run --- paddle/fluid/framework/parallel_executor.cc | 12 +++++++++--- paddle/fluid/framework/parallel_executor.h | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f841b3b7fa8..0f9bc869725 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -482,7 +482,6 @@ void ParallelExecutor::ConstructDependencyGraph( bool is_forwarding = true; for (auto *op : main_program.Block(0).AllOps()) { bool change_forward = false; - if (!is_forwarding) { // FIXME(yy): Do not hard code like this if (op->OutputArgumentNames().size() == 1 && @@ -573,7 +572,7 @@ void ParallelExecutor::ConstructDependencyGraph( Dependency graph has been constructed. However, there are still data harzaeds need to be handled. */ - PolishGraphToSupportDataHarzaeds(); + PolishGraphToSupportDataHazards(); } /** @@ -583,7 +582,7 @@ void ParallelExecutor::ConstructDependencyGraph( * * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) */ -void ParallelExecutor::PolishGraphToSupportDataHarzaeds() const { +void ParallelExecutor::PolishGraphToSupportDataHazards() const { for (auto &place_pair : member_->vars_) { for (auto &name_pair : place_pair.second) { if (name_pair.second.size() <= 1) { @@ -813,6 +812,13 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, fetch_ops.clear(); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetched_data->tensors_; + + // FIXME: + // It could be optimized by using multiple events in an operator. + // Manually sync computation during iter. + for (auto &p : member_->places_) { + platform::DeviceContextPool::Instance().Get(p)->Wait(); + } } void ParallelExecutor::RunOp( diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 03bf60b8bc4..cb93c0cd410 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -65,7 +65,7 @@ class ParallelExecutor { std::unordered_map>& pending_vars, OpHandle* op) const; - void PolishGraphToSupportDataHarzaeds() const; + void PolishGraphToSupportDataHazards() const; }; } // namespace framework -- GitLab