提交 f4f4816b 编写于 作者: Q Qiao Longfei

fix gpu error test=develop

上级 12f6b8c3
...@@ -29,6 +29,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( ...@@ -29,6 +29,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
graphs_(std::move(graphs)) { graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor"; VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
PADDLE_ENFORCE_EQ(graphs_.size, local_scopes_.size());
// set the correct size of thread pool to each device. // set the correct size of thread pool to each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size() strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
......
...@@ -261,10 +261,21 @@ ParallelExecutor::ParallelExecutor( ...@@ -261,10 +261,21 @@ ParallelExecutor::ParallelExecutor(
// ncclOp // ncclOp
std::vector<std::unique_ptr<ir::Graph>> graphs; std::vector<std::unique_ptr<ir::Graph>> graphs;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
VLOG(3) << "use local async mode";
for (size_t i = 0; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
} else {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply( std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_, main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get()); member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph)); graphs.push_back(std::move(graph));
}
#else #else
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
VLOG(3) << "use local async mode"; VLOG(3) << "use local async mode";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册