Unverified · Commit 81217a94 authored by Leo Chen, committed by GitHub

unify calling cudaSetDevice (#30470)

* unify calling cudaSetDevice

* fix compile
Parent: 00554b3f
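
The hunks below all make the same change: call sites that invoked cudaSetDevice directly now go through Paddle's platform::SetDeviceId wrapper, which range-checks the device id and retries the driver call. A minimal self-contained sketch of that pattern (the function name, retry count, and abort-on-failure handling here are assumptions for illustration, not Paddle's exact implementation):

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    // Hypothetical stand-in for platform::SetDeviceId: validate the id against
    // the visible device count, then set the device, retrying transient failures.
    void SetDeviceIdSketch(int id) {
      int count = 0;
      if (cudaGetDeviceCount(&count) != cudaSuccess || id < 0 || id >= count) {
        std::fprintf(stderr, "invalid device id %d (device count %d)\n", id, count);
        std::abort();
      }
      cudaError_t err = cudaSuccess;
      for (int attempt = 0; attempt < 3; ++attempt) {  // bounded retry: an assumption
        err = cudaSetDevice(id);
        if (err == cudaSuccess) return;
      }
      std::fprintf(stderr, "cudaSetDevice failed: %s\n", cudaGetErrorString(err));
      std::abort();
    }

With one wrapper in place, every call site gets the same validation and error handling instead of repeating PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(...)) by hand.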
@@ -94,7 +94,7 @@ class NCCLOpHandleBase : public OpHandleBase {
         continue;
       }
-      PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
+      platform::SetDeviceId(dev_id);
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
           &inter_events_[dev_id], cudaEventDisableTiming));
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
......
@@ -47,7 +47,7 @@ void OpHandleBase::InitCUDA() {
 #ifdef PADDLE_WITH_CUDA
   for (auto &p : dev_ctxes_) {
     int dev_id = BOOST_GET_CONST(platform::CUDAPlace, p.first).device;
-    PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
+    platform::SetDeviceId(dev_id);
     PADDLE_ENFORCE_CUDA_SUCCESS(
         cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
   }
......
@@ -50,7 +50,7 @@ void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
   nccl_info_.local_rank_ = local_rank;
   nccl_info_.my_global_rank_ = global_rank;
   nccl_info_.global_ranks_ = ranks;
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(local_rank));
+  platform::SetDeviceId(local_rank);
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(nccl_info_.stream_)));
 #endif
   return;
......
@@ -339,7 +339,7 @@ void TensorRTEngine::freshDeviceId() {
                     platform::errors::OutOfRange(
                         "Device id %d exceeds the current device count: %d.",
                         device_id_, count));
-  cudaSetDevice(device_id_);
+  platform::SetDeviceId(device_id_);
 }
 }  // namespace tensorrt
......
@@ -64,7 +64,7 @@ void MultiStreamCompute(float **data, float **second_data,
 TEST(Malloc, CUDADeviceContextMultiStream) {
   auto place = platform::CUDAPlace(0);
-  EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
+  platform::SetDeviceId(0);
   AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
   EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
@@ -94,7 +94,7 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
 TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
   auto place = platform::CUDAPlace(0);
-  EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
+  platform::SetDeviceId(0);
   AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
   EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
......
@@ -75,7 +75,7 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
                         "Expected dev_id >= 0. But received dev_id is %d.", dev_id));
   ncclComm_t comm = nullptr;
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
+  SetDeviceId(dev_id);
   PADDLE_ENFORCE_CUDA_SUCCESS(
       platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));
......
@@ -226,7 +226,7 @@ void SetDeviceId(int id) {
                         "Device id must be less than GPU count, "
                         "but received id is: %d. GPU count is: %d.",
                         id, GetCUDADeviceCount()));
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(id));
+  PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
 }
 void GpuMemoryUsage(size_t *available, size_t *total) {
......
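The hunk above is the key one: SetDeviceId itself switches from the fail-on-first-error macro to PADDLE_RETRY_CUDA_SUCCESS, so retrying now happens in exactly one place. A hedged sketch of what such a retry macro can look like (the name, retry count, and failure handling below are assumptions, not Paddle's actual definition):

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    // Hypothetical retry macro: re-issue the CUDA call a fixed number of times
    // and abort with the error string if it never succeeds.
    #define RETRY_CUDA_SUCCESS_SKETCH(call)                                   \
      do {                                                                    \
        cudaError_t err_ = cudaSuccess;                                       \
        for (int i_ = 0; i_ < 3; ++i_) {                                      \
          err_ = (call);                                                      \
          if (err_ == cudaSuccess) break;                                     \
        }                                                                     \
        if (err_ != cudaSuccess) {                                            \
          std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); \
          std::abort();                                                       \
        }                                                                     \
      } while (0)

A plausible motivation for retrying here is that cudaSetDevice can fail transiently on a heavily loaded machine; a bounded retry turns such a failure into a recoverable condition instead of an immediate crash.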
@@ -132,7 +132,7 @@ struct NCCLContextMap {
       }
       VLOG(1) << "init nccl rank:" << rank << ", nranks:" << nranks
               << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
-      PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(gpu_id));
+      SetDeviceId(gpu_id);
       PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
           comms.get() + i, nranks, *nccl_id, rank));
     }
......
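With the unification done, per-device setup across these files follows one idiom: select the device through the wrapper, then create the stream or event on it. A short usage sketch reusing the hypothetical SetDeviceIdSketch from above (error handling elided for brevity):

    // Create a timing-disabled event on a given device, in the style of the
    // OpHandleBase and NCCLOpHandleBase hunks above.
    cudaEvent_t CreateEventOnDevice(int dev_id) {
      SetDeviceIdSketch(dev_id);  // hypothetical stand-in for platform::SetDeviceId
      cudaEvent_t ev = nullptr;
      cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);  // check result in real code
      return ev;
    }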