diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f841b3b7fa84c5a69fe510b5fc239d949c2212b5..0f9bc869725d496ac46b8aea704269c141ba6816 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -482,7 +482,6 @@ void ParallelExecutor::ConstructDependencyGraph( bool is_forwarding = true; for (auto *op : main_program.Block(0).AllOps()) { bool change_forward = false; - if (!is_forwarding) { // FIXME(yy): Do not hard code like this if (op->OutputArgumentNames().size() == 1 && @@ -573,7 +572,7 @@ void ParallelExecutor::ConstructDependencyGraph( Dependency graph has been constructed. However, there are still data harzaeds need to be handled. */ - PolishGraphToSupportDataHarzaeds(); + PolishGraphToSupportDataHazards(); } /** @@ -583,7 +582,7 @@ void ParallelExecutor::ConstructDependencyGraph( * * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) */ -void ParallelExecutor::PolishGraphToSupportDataHarzaeds() const { +void ParallelExecutor::PolishGraphToSupportDataHazards() const { for (auto &place_pair : member_->vars_) { for (auto &name_pair : place_pair.second) { if (name_pair.second.size() <= 1) { @@ -813,6 +812,13 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, fetch_ops.clear(); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetched_data->tensors_; + + // FIXME: + // This could be optimized by using multiple events per operator. + // Until then, manually wait on every place's device context each iteration. 
+ for (auto &p : member_->places_) { + platform::DeviceContextPool::Instance().Get(p)->Wait(); + } } void ParallelExecutor::RunOp( diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 03bf60b8bc44670b6b9e1a9b318655838d7d22e4..cb93c0cd4103813463f39ec7d3d51debbd6e15f6 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -65,7 +65,7 @@ class ParallelExecutor { std::unordered_map>& pending_vars, OpHandle* op) const; - void PolishGraphToSupportDataHarzaeds() const; + void PolishGraphToSupportDataHazards() const; }; } // namespace framework