未验证 提交 4978db2c 编写于 作者: C chengduo 提交者: GitHub

Remove nccl dep when the number of GPU is 1 (#18158)

* remove nccl dep when the number of GPU is 1
test=develop
上级 25ab23be
...@@ -369,8 +369,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -369,8 +369,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
"Execution which can get better performance," "Execution which can get better performance,"
<< "you can force it off by env FLAGS_enable_parallel_graph=0"; << "you can force it off by env FLAGS_enable_parallel_graph=0";
if (member_->use_cuda_) { if (member_->use_cuda_ && member_->nranks_ > 1) {
// Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
member_->InitOrGetNCCLCommunicator(scope, build_strategy); member_->InitOrGetNCCLCommunicator(scope, build_strategy);
...@@ -405,10 +404,11 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -405,10 +404,11 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
} }
return false; return false;
}; };
// Bcast Parameters to all GPUs
if (need_broadcast()) { if (need_broadcast()) {
BCastParamsToDevices(bcast_vars, build_strategy.trainer_id_); BCastParamsToDevices(bcast_vars, build_strategy.trainer_id_);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
......
...@@ -316,7 +316,9 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -316,7 +316,9 @@ CUDADeviceContext::~CUDADeviceContext() {
eigen_device_.reset(); eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_));
#if !defined(_WIN32) #if !defined(_WIN32)
PADDLE_ENFORCE(dynload::ncclCommDestroy(nccl_comm_)); if (nccl_comm_) {
PADDLE_ENFORCE(dynload::ncclCommDestroy(nccl_comm_));
}
#endif #endif
} }
......
...@@ -223,5 +223,5 @@ if(WITH_DISTRIBUTE) ...@@ -223,5 +223,5 @@ if(WITH_DISTRIBUTE)
endif() endif()
set_tests_properties(test_recordio_reader test_parallel_executor_test_while_train test_parallel_executor_mnist set_tests_properties(test_recordio_reader test_parallel_executor_test_while_train test_parallel_executor_mnist
test_parallel_executor_seresnext test_parallel_executor_crf test_parallel_executor_seresnext test_parallel_executor_crf test_sync_batch_norm_op
PROPERTIES LABELS "RUN_TYPE=DIST") PROPERTIES LABELS "RUN_TYPE=DIST")
...@@ -98,6 +98,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): ...@@ -98,6 +98,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
##################################################################### #####################################################################
# Multi-GPUs, self.N / core.get_cuda_device_count() per GPU # Multi-GPUs, self.N / core.get_cuda_device_count() per GPU
assert core.get_cuda_device_count() > 1
main, startup, outs = self.build_program(place, layout, seed, True, main, startup, outs = self.build_program(place, layout, seed, True,
only_forward) only_forward)
exe = fluid.Executor(place) exe = fluid.Executor(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册