diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 3f8da69fc2f1a0ffeeb55328e1fd25134e57af5c..5c0dcdad3a077efb919bbcae367cc7f48d815486 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -145,9 +145,9 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
 class CudnnHolder {
  public:
   CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
-      : stream_(stream), place_(place), workspace_(nullptr), workspace_len_(0) {
+      : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
     PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
-    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
+    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
   }
 
   cudnnHandle_t get_cudnn_handle() const { return cudnn_handle_; }
@@ -157,14 +157,14 @@ class CudnnHolder {
       void* new_workspace = paddle::memory::Alloc(place_, required_len);
       if (workspace_ != nullptr) {
         // Maybe someone is using the current workspace
-        PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
+        PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
         PADDLE_ENFORCE(cudaGetLastError());
         paddle::memory::Free(place_, workspace_);
       }
       workspace_ = new_workspace;
       workspace_len_ = required_len;
     }
-    return workspace_
+    return workspace_;
   }
 
   ~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
@@ -231,7 +231,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
   return cudnn_holder_->get_cudnn_handle();
 }
 
-void* cudnn_workspace(size_t required_len) const {
+void* CUDADeviceContext::cudnn_workspace(size_t required_len) const {
   return cudnn_holder_->get_workspace(required_len);
 }
 
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 7314d91f3ea6b77e9353279c86288bfc4f428480..5bcd04fa02f18f35e09a9ee4c415ff01d0f5923e 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -69,7 +69,7 @@ struct DefaultDeviceContextType {
 #ifdef PADDLE_WITH_CUDA
 
 class EigenCudaStreamDevice;
-class CUDNNHolder;
+class CudnnHolder;
 
 class CUDADeviceContext : public DeviceContext {
  public: