diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 2cc26da013f59f5b7ee1747d57baca9c1c0efe2c..01fa9301d61045b8ae5d04dfec84e7a593b38059 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -142,7 +142,43 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { mutable unsigned int* semaphore_; }; -CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { +class CudnnHolder { + public: + CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) + : stream_(stream), place_(place), workspace_(nullptr), workspace_len_(0) { + PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); + } + + cudnnHandle_t get_cudnn_handle() const { return cudnn_handle_; } + + void* get_workspace(size_t required_len) { + if (required_len > workspace_len_) { + void* new_workspace = paddle::memory::Alloc(place_, required_len); + if (workspace_ != nullptr) { + // Maybe someone is using the current workspace + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + PADDLE_ENFORCE(cudaGetLastError()); + paddle::memory::Free(place_, workspace_); + } + workspace_ = new_workspace; + } + return workspace_ + } + + ~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } + + private: + cudnnHandle_t cudnn_handle_; + void* workspace_; + size_t workspace_len_; + + const cudaStream_t* stream_; // not owned; + const CUDAPlace place_; +}; + +CUDADeviceContext::CUDADeviceContext(CUDAPlace place) + : place_(place), cudnn_holder_(nullptr) { SetDeviceId(place_.device); compute_capability = GetCUDAComputeCapability(place_.device); multi_process = GetCUDAMultiProcessors(place_.device); @@ -154,10 +190,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); if (dynload::HasCUDNN()) { - PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); - PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); - } else { - cudnn_handle_ = nullptr; + cudnn_holder_.reset(new CudnnHolder(&stream_, place)); } } @@ -165,9 +198,6 @@ CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); - if (cudnn_handle_ != nullptr) { - PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); - } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -196,7 +226,13 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const { return cublas_handle_; } -cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } +cudnnHandle_t CUDADeviceContext::cudnn_handle() const { + return cudnn_holder_->get_cudnn_handle(); +} + +void* cudnn_workspace(size_t required_len) const { + return cudnn_holder_->get_workspace(required_len); +} cudaStream_t CUDADeviceContext::stream() const { return stream_; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 88e0383146c1adf2752a362091996bad9cfcce5e..7314d91f3ea6b77e9353279c86288bfc4f428480 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -69,6 +69,7 @@ struct DefaultDeviceContextType { #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; +class CUDNNHolder; class CUDADeviceContext : public DeviceContext { public: @@ -96,6 +97,10 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; + /*! \brief Return a cudnn workspace whose length is greater than the + * 'required_len'. */ + void* cudnn_workspace(size_t required_len) const; + /*! \brief Return cuda stream in the device context. */ cudaStream_t stream() const; @@ -111,8 +116,8 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; + std::unique_ptr cudnn_holder_; cudaStream_t stream_; - cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; int compute_capability;