diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 36e7f293482899168b00815c73aea9cea461bed5..018e9d19b397bf5edaa2b3fe0e02db00afde1ec6 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -160,29 +160,26 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { }; CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) - : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) { + : workspace_(nullptr), stream_(stream), place_(place) { PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_)); } CudnnHolder::~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); - if (workspace_ != nullptr) { - paddle::memory::Free(place_, workspace_); - } } void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) { - if (required_workspace_len <= workspace_len_) { + if (required_workspace_len <= WorkspaceSize()) { return; } if (workspace_ != nullptr) { // Maybe someone is using the current workspace PADDLE_ENFORCE(cudaStreamSynchronize(*stream_)); - paddle::memory::Free(place_, workspace_); + workspace_.reset(); } - workspace_ = paddle::memory::Alloc(place_, required_workspace_len); - workspace_len_ = required_workspace_len; + workspace_ = paddle::memory::Alloc(place_, required_workspace_len, + paddle::memory::Allocator::kScratchpad); } CUDADeviceContext::CUDADeviceContext(CUDAPlace place) diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index df248f9bb15591d5015ad01278797ec7e31ef9d1..0e7799833582962395fd25db32a5307458886063 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include #include - +#include "paddle/fluid/memory/malloc.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" @@ -85,17 +85,32 @@ class CudnnHolder { template void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) { - if (required_workspace_len > workspace_len_) { + if (required_workspace_len > WorkspaceSize()) { ReallocateWorkspace(required_workspace_len); } - cudnn_func(workspace_); + cudnn_func(WorkspacePtr()); + } + + inline void* WorkspacePtr() { + if (workspace_) { + return workspace_->ptr(); + } else { + return nullptr; + } + } + + inline size_t WorkspaceSize() { + if (workspace_) { + return workspace_->size(); + } else { + return 0; + } } std::mutex& Mutex() { return mtx_; } cudnnHandle_t cudnn_handle_; - void* workspace_; - size_t workspace_len_; + std::unique_ptr workspace_; const cudaStream_t* stream_; // not owned; const CUDAPlace place_;