diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 961e8271e18ef8613287c7e266a23e98ed8cdd1a..48002a7620221aad90926ccd3eb89319d6e516b4 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -212,6 +212,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) : workspace_(nullptr), stream_(stream), place_(place) { + PADDLE_ENFORCE(cudaSetDevice(place_.device)); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_)); } @@ -233,8 +234,6 @@ void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) { paddle::memory::Allocator::kScratchpad); } -std::once_flag CUDADeviceContext::init_cudnn_; - CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place), cudnn_holder_(nullptr) { CUDADeviceGuard guard(place_.device); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 296ecfc6f07e005a8db1ef81732548b004461bee..778f6613bd49dfbc46e8888cd53b1a4de5fe923d 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -292,7 +292,7 @@ class CUDADeviceContext : public DeviceContext { private: CUDAPlace place_; - static std::once_flag init_cudnn_; + mutable std::once_flag init_cudnn_; std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_;