From 06936a2ff59ba67f6be0526bf97c26a3cf036b18 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 18 Dec 2018 11:16:14 +0800 Subject: [PATCH] fix 1gpu test=develop --- paddle/fluid/framework/details/all_reduce_op_handle.cc | 3 ++- paddle/fluid/framework/details/parallel_ssa_graph_executor.cc | 2 +- paddle/fluid/framework/parallel_executor.cc | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index 6bca29981..4a0347d07 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -51,7 +51,8 @@ void AllReduceOpHandle::RunImpl() { // FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR, // this is a distributed or inter-process call, find a better way. #ifdef PADDLE_WITH_CUDA - // Find NCCL ID from the global scope. + // All-reduce op_handle can run on the sub-scope, find the nccl id from + // the global scope. if (NoDummyInputSize() == 1 && local_scopes_[0]->FindVar(NCCL_ID_VARNAME) == nullptr) { #else diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index 845c4379e..2377f2c96 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -59,7 +59,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run( if (pool_) { run_futures.emplace_back(pool_->enqueue(std::move(call))); } else { - call(); + fetch_datas.emplace_back(std::move(call())); } } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 152b9b270..0042ccaa4 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -231,7 +231,7 @@ ParallelExecutor::ParallelExecutor( #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); ncclUniqueId *nccl_id = nullptr; - if (build_strategy.enable_parallel_graph_) { + if (build_strategy.enable_parallel_graph_ && places.size() > 1) { // parallel graph mode should initialize nccl by ncclCommInitRank since // it call nccl operator per device per thread. if (nccl_id_var == nullptr) { -- GitLab