提交 a1d11bb1 编写于 作者: N nhzlx

fix ci bug: cudnn handler in multi card

test=develop
上级 3df7b98a
...@@ -212,6 +212,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -212,6 +212,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), stream_(stream), place_(place) { : workspace_(nullptr), stream_(stream), place_(place) {
PADDLE_ENFORCE(cudaSetDevice(place_.device));
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_)); PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
} }
...@@ -233,8 +234,6 @@ void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) { ...@@ -233,8 +234,6 @@ void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
paddle::memory::Allocator::kScratchpad); paddle::memory::Allocator::kScratchpad);
} }
std::once_flag CUDADeviceContext::init_cudnn_;
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) { : place_(place), cudnn_holder_(nullptr) {
CUDADeviceGuard guard(place_.device); CUDADeviceGuard guard(place_.device);
......
...@@ -292,7 +292,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -292,7 +292,7 @@ class CUDADeviceContext : public DeviceContext {
private: private:
CUDAPlace place_; CUDAPlace place_;
static std::once_flag init_cudnn_; mutable std::once_flag init_cudnn_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册