Unverified · Commit 81217a94 · authored by Leo Chen · committed by GitHub

unify calling cudaSetDevice (#30470)

* unify calling cudaSetDevice

* fix compile
Parent 00554b3f
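Note: every call site below now goes through platform::SetDeviceId instead of calling cudaSetDevice directly, so device-id validation and error handling live in one place. The following is a minimal standalone sketch of that pattern only; the helper name SetDeviceIdChecked and its error handling are hypothetical and are not Paddle's implementation (Paddle's real SetDeviceId is the function touched in the gpu_info hunk further down).

// Illustrative sketch, not Paddle code. Build with: nvcc demo.cu -o demo
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Single choke point for selecting the active GPU: validate the id against
// the device count, then call cudaSetDevice and check its return code.
void SetDeviceIdChecked(int id) {
  int count = 0;
  cudaError_t err = cudaGetDeviceCount(&count);
  if (err != cudaSuccess || id < 0 || id >= count) {
    std::fprintf(stderr, "Invalid device id %d (device count: %d)\n", id, count);
    std::exit(EXIT_FAILURE);
  }
  err = cudaSetDevice(id);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaSetDevice(%d) failed: %s\n", id,
                 cudaGetErrorString(err));
    std::exit(EXIT_FAILURE);
  }
}

int main() {
  SetDeviceIdChecked(0);  // every call site selects the device the same way
  std::printf("device 0 selected\n");
  return 0;
}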
@@ -94,7 +94,7 @@ class NCCLOpHandleBase : public OpHandleBase {
         continue;
       }
-      PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
+      platform::SetDeviceId(dev_id);
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
           &inter_events_[dev_id], cudaEventDisableTiming));
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
......
@@ -47,7 +47,7 @@ void OpHandleBase::InitCUDA() {
 #ifdef PADDLE_WITH_CUDA
   for (auto &p : dev_ctxes_) {
     int dev_id = BOOST_GET_CONST(platform::CUDAPlace, p.first).device;
-    PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
+    platform::SetDeviceId(dev_id);
     PADDLE_ENFORCE_CUDA_SUCCESS(
         cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
   }
......
@@ -50,7 +50,7 @@ void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
   nccl_info_.local_rank_ = local_rank;
   nccl_info_.my_global_rank_ = global_rank;
   nccl_info_.global_ranks_ = ranks;
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(local_rank));
+  platform::SetDeviceId(local_rank);
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(nccl_info_.stream_)));
 #endif
   return;
......
@@ -339,7 +339,7 @@ void TensorRTEngine::freshDeviceId() {
                     platform::errors::OutOfRange(
                         "Device id %d exceeds the current device count: %d.",
                         device_id_, count));
-  cudaSetDevice(device_id_);
+  platform::SetDeviceId(device_id_);
 }
 }  // namespace tensorrt
......
@@ -64,7 +64,7 @@ void MultiStreamCompute(float **data, float **second_data,
 TEST(Malloc, CUDADeviceContextMultiStream) {
   auto place = platform::CUDAPlace(0);
-  EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
+  platform::SetDeviceId(0);
   AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
   EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
@@ -94,7 +94,7 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
 TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
   auto place = platform::CUDAPlace(0);
-  EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
+  platform::SetDeviceId(0);
   AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
   EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
......
@@ -75,7 +75,7 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
                         "Expected dev_id >= 0. But received dev_id is %d.", dev_id));
   ncclComm_t comm = nullptr;
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
+  SetDeviceId(dev_id);
   PADDLE_ENFORCE_CUDA_SUCCESS(
       platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));
......
@@ -226,7 +226,7 @@ void SetDeviceId(int id) {
                         "Device id must be less than GPU count, "
                         "but received id is: %d. GPU count is: %d.",
                         id, GetCUDADeviceCount()));
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(id));
+  PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
 }
 void GpuMemoryUsage(size_t *available, size_t *total) {
......
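The gpu_info hunk above also switches SetDeviceId itself from PADDLE_ENFORCE_CUDA_SUCCESS to PADDLE_RETRY_CUDA_SUCCESS, i.e. a failing cudaSetDevice is retried before the process gives up. The macro's internals are not part of this diff; the sketch below only illustrates the general retry-then-fail idea under the assumption of a fixed retry count, and is not the actual macro.

// Hypothetical retry helper, for illustration only. Build with: nvcc retry.cu -o retry
#include <cuda_runtime.h>
#include <chrono>
#include <cstdio>
#include <functional>
#include <thread>

// Retry a CUDA runtime call a few times before reporting failure.
bool RetryCudaSuccess(const std::function<cudaError_t()>& call,
                      int max_retries = 3) {
  cudaError_t err = cudaSuccess;
  for (int attempt = 0; attempt <= max_retries; ++attempt) {
    err = call();
    if (err == cudaSuccess) return true;
    cudaGetLastError();  // clear the error state before the next attempt
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
  std::fprintf(stderr, "CUDA call failed after %d retries: %s\n", max_retries,
               cudaGetErrorString(err));
  return false;
}

int main() {
  // Same spirit as PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id)) in the hunk above.
  return RetryCudaSuccess([] { return cudaSetDevice(0); }) ? 0 : 1;
}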
@@ -132,7 +132,7 @@ struct NCCLContextMap {
       }
       VLOG(1) << "init nccl rank:" << rank << ", nranks:" << nranks
               << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
-      PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(gpu_id));
+      SetDeviceId(gpu_id);
       PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
           comms.get() + i, nranks, *nccl_id, rank));
     }
......