提交 04bfd5c1 编写于 作者: F fengjiayi

add CudnnHolder to manage cudnn_handle and workspace

上级 a615ad46
...@@ -142,7 +142,43 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -142,7 +142,43 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_; mutable unsigned int* semaphore_;
}; };
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: stream_(stream), place_(place), workspace_(nullptr), workspace_len_(0) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
cudnnHandle_t get_cudnn_handle() const { return cudnn_handle_; }
void* get_workspace(size_t required_len) {
if (required_len > workspace_len_) {
void* new_workspace = paddle::memory::Alloc(place_, required_len);
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError());
paddle::memory::Free(place_, workspace_);
}
workspace_ = new_workspace;
}
return workspace_
}
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
private:
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
};
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
compute_capability = GetCUDAComputeCapability(place_.device); compute_capability = GetCUDAComputeCapability(place_.device);
multi_process = GetCUDAMultiProcessors(place_.device); multi_process = GetCUDAMultiProcessors(place_.device);
...@@ -154,10 +190,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { ...@@ -154,10 +190,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
if (dynload::HasCUDNN()) { if (dynload::HasCUDNN()) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); cudnn_holder_.reset(new CudnnHolder(&stream_, place));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
} else {
cudnn_handle_ = nullptr;
} }
} }
...@@ -165,9 +198,6 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -165,9 +198,6 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device); SetDeviceId(place_.device);
Wait(); Wait();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
if (cudnn_handle_ != nullptr) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_));
...@@ -196,7 +226,13 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const { ...@@ -196,7 +226,13 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_; return cublas_handle_;
} }
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->get_cudnn_handle();
}
void* cudnn_workspace(size_t required_len) const {
return cudnn_holder_->get_workspace(required_len);
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
...@@ -69,6 +69,7 @@ struct DefaultDeviceContextType<platform::CPUPlace> { ...@@ -69,6 +69,7 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice; class EigenCudaStreamDevice;
class CUDNNHolder;
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
...@@ -96,6 +97,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -96,6 +97,10 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */ /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const; cudnnHandle_t cudnn_handle() const;
/*! \brief Return a cudnn workspace whose length is greater than the
* 'required_len'. */
void* cudnn_workspace(size_t required_len) const;
/*! \brief Return cuda stream in the device context. */ /*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const; cudaStream_t stream() const;
...@@ -111,8 +116,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -111,8 +116,8 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_; cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
int compute_capability; int compute_capability;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册