diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index 6bca299813f166009bc33512e2154907d869cf56..4a0347d07a815f3110d1a78381b0bd006fe2abaa 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -51,7 +51,8 @@ void AllReduceOpHandle::RunImpl() {
   // FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
   // this is a distributed or inter-process call, find a better way.
 #ifdef PADDLE_WITH_CUDA
-  // Find NCCL ID from the global scope.
+  // All-reduce op_handle can run on the sub-scope, find the nccl id from
+  // the global scope.
   if (NoDummyInputSize() == 1 &&
       local_scopes_[0]->FindVar(NCCL_ID_VARNAME) == nullptr) {
 #else
diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
index 845c4379e6fb485b21b46e8b7ccf0819223bad41..2377f2c963d25bf52c66cd8c152fefeff48271ea 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -59,7 +59,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
     if (pool_) {
       run_futures.emplace_back(pool_->enqueue(std::move(call)));
     } else {
-      call();
+      fetch_datas.emplace_back(std::move(call()));
     }
   }
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 152b9b2702585cd56a04d7a0ac4b70fc70f6b94d..0042ccaa4f8aa77c7652dc07db1360134fbd935d 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -231,7 +231,7 @@ ParallelExecutor::ParallelExecutor(
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
   ncclUniqueId *nccl_id = nullptr;
-  if (build_strategy.enable_parallel_graph_) {
+  if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
     // parallel graph mode should initialize nccl by ncclCommInitRank since
     // it call nccl operator per device per thread.
     if (nccl_id_var == nullptr) {
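---
Notes on the hunks above (illustrative sketches, not part of the patch):

1) all_reduce_op_handle.cc: the reworded comment points out that the
all-reduce op_handle may execute against a sub-scope while the NCCL id is
registered in the global (root) scope. Paddle's Scope::FindVar walks parent
scopes, which is what lets a sub-scope lookup reach the global scope; a
minimal model of that behavior, with a simplified Scope type assumed here:

// Minimal, assumed model of hierarchical scope lookup: a variable not
// found in a sub-scope is searched for in its parent, so a lookup issued
// from a sub-scope can still reach an id held by the global scope.
#include <map>
#include <string>

struct Var {};

struct Scope {
  const Scope *parent = nullptr;
  std::map<std::string, Var> vars;

  const Var *FindVar(const std::string &name) const {
    auto it = vars.find(name);
    if (it != vars.end()) return &it->second;
    return parent ? parent->FindVar(name) : nullptr;
  }
};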
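2) parallel_ssa_graph_executor.cc: in the no-thread-pool branch the old code
invoked call() and discarded its FeedFetchList result, so the serial path
returned empty fetch data; the fix stores the result. A standalone sketch of
the fixed collection pattern, where std::async stands in for the executor's
thread pool and the FeedFetchList alias is a stand-in type (both assumptions):

// Sketch of collecting per-graph fetch results on both paths; the pre-fix
// bug only affected the synchronous branch.
#include <functional>
#include <future>
#include <utility>
#include <vector>

using FeedFetchList = std::vector<float>;  // stand-in for the real type

std::vector<FeedFetchList> RunAll(
    std::vector<std::function<FeedFetchList()>> calls, bool has_pool) {
  std::vector<std::future<FeedFetchList>> run_futures;
  std::vector<FeedFetchList> fetch_datas;
  for (auto &call : calls) {
    if (has_pool) {
      // Asynchronous path: the result arrives later through the future.
      run_futures.emplace_back(
          std::async(std::launch::async, std::move(call)));
    } else {
      // Synchronous path: keep the result immediately. The pre-fix code
      // ran call() and dropped the returned list, losing the fetch data.
      fetch_datas.emplace_back(call());
    }
  }
  for (auto &f : run_futures) fetch_datas.emplace_back(f.get());
  return fetch_datas;
}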
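3) parallel_executor.cc: the ncclCommInitRank bootstrap is only needed when
parallel-graph mode actually spans multiple devices, hence the added
places.size() > 1 guard; with a single place there is no cross-device
all-reduce to set up. A sketch of the per-device-per-thread initialization
scheme the in-code comment describes (error handling elided; this is an
illustration, not the ParallelExecutor code path itself):

// Every rank joins the same communicator via ncclCommInitRank; the call is
// a collective, so the per-rank calls must run concurrently, here one
// thread per device.
#include <cuda_runtime.h>
#include <nccl.h>

#include <thread>
#include <vector>

void InitPerDeviceComms(int ndev, std::vector<ncclComm_t> *comms) {
  ncclUniqueId id;
  ncclGetUniqueId(&id);  // one unique id shared by every rank
  comms->resize(ndev);
  std::vector<std::thread> threads;
  for (int rank = 0; rank < ndev; ++rank) {
    threads.emplace_back([&, rank] {
      cudaSetDevice(rank);  // one thread drives one device
      ncclCommInitRank(&(*comms)[rank], ndev, id, rank);
    });
  }
  for (auto &t : threads) t.join();
}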