From f4f4816b0c1ffdf7689523f732cd728c196e5aff Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 22 Feb 2019 16:26:50 +0800 Subject: [PATCH] fix gpu error test=develop --- .../details/async_ssa_graph_executor.cc | 1 + paddle/fluid/framework/parallel_executor.cc | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 5ce92ad8267..0780fb040a6 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -29,6 +29,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( graphs_(std::move(graphs)) { VLOG(3) << "build AsyncSSAGraphExecutor"; PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); + PADDLE_ENFORCE_EQ(graphs_.size, local_scopes_.size()); // set the correct size of thread pool to each device. strategy_.num_threads_ = strategy_.num_threads_ < places_.size() diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ecae729124c..cfd6609a4b1 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -261,10 +261,21 @@ ParallelExecutor::ParallelExecutor( // ncclOp std::vector> graphs; #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - std::unique_ptr graph = build_strategy.Apply( - main_program, member_->places_, loss_var_name, member_->local_scopes_, - member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get()); - graphs.push_back(std::move(graph)); + if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { + VLOG(3) << "use local async mode"; + for (size_t i = 0; i < member_->places_.size(); ++i) { + std::unique_ptr graph = build_strategy.Apply( + main_program, {member_->places_[i]}, loss_var_name, + {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_, + member_->nccl_ctxs_.get()); + graphs.push_back(std::move(graph)); + } + } else { + std::unique_ptr graph = build_strategy.Apply( + main_program, member_->places_, loss_var_name, member_->local_scopes_, + member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get()); + graphs.push_back(std::move(graph)); + } #else if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { VLOG(3) << "use local async mode"; -- GitLab