Review notes (text ahead of the first "diff --git" header is ignored by
git apply and patch):
  * Fixed the check added to async_ssa_graph_executor.cc: `graphs_.size`
    named the member function without calling it; it is now `graphs_.size()`.
  * The CUDA (non-Windows) branch of the ParallelExecutor constructor now
    builds one graph per place in local async mode, using the same condition
    the #else branch already tests.
  * NOTE(review): template arguments had been stripped from this patch
    ("std::vector> graphs", "std::unique_ptr graph"); they are restored here
    as std::unique_ptr<ir::Graph> -- confirm against the tree.
  * NOTE(review): line breaks and indentation were rebuilt from a collapsed
    copy, and the post-image hashes on the "index" lines predate the edits
    above -- regenerate with git diff before applying.

diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index 5ce92ad826741c3cc7240256be7a13db89daada4..0780fb040a6fa0160d6afc3cb6a014c8a6746d71 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -29,6 +29,8 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
       graphs_(std::move(graphs)) {
   VLOG(3) << "build AsyncSSAGraphExecutor";
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
+  // graphs_ must line up one-to-one with local_scopes_ as well.
+  PADDLE_ENFORCE_EQ(graphs_.size(), local_scopes_.size());
 
   // set the correct size of thread pool to each device.
   strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index ecae729124c21602ed06a72e9c3d71768646605d..cfd6609a4b14fe852fe38ec6300f790b1ad8fae4 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -261,10 +261,24 @@ ParallelExecutor::ParallelExecutor(
   // ncclOp
   std::vector<std::unique_ptr<ir::Graph>> graphs;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-      main_program, member_->places_, loss_var_name, member_->local_scopes_,
-      member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
-  graphs.push_back(std::move(graph));
+  if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
+    // Local (non-distributed) async mode: build a separate single-place graph
+    // for every place/scope pair instead of one graph spanning all places.
+    VLOG(3) << "use local async mode";
+    for (size_t i = 0; i < member_->places_.size(); ++i) {
+      std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
+          main_program, {member_->places_[i]}, loss_var_name,
+          {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
+          member_->nccl_ctxs_.get());
+      graphs.push_back(std::move(graph));
+    }
+  } else {
+    // Otherwise build one graph over all places and scopes.
+    std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
+        main_program, member_->places_, loss_var_name, member_->local_scopes_,
+        member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
+    graphs.push_back(std::move(graph));
+  }
 #else
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use local async mode";