From 151e169eb75a8ee96e0c1e50605fa811cb65acf4 Mon Sep 17 00:00:00 2001 From: guochaorong <32069604+guochaorong@users.noreply.github.com> Date: Tue, 4 Sep 2018 08:14:42 +0800 Subject: [PATCH] Revert "Add CudnnHolder and use it in Conv and ConvTranspose op" --- paddle/fluid/framework/rw_lock.h | 71 ------------------ paddle/fluid/operators/conv_cudnn_op.cu.cc | 57 +++++++------- .../operators/conv_transpose_cudnn_op.cu.cc | 59 ++++++++------- paddle/fluid/platform/device_context.cc | 74 +++---------------- paddle/fluid/platform/device_context.h | 8 +- 5 files changed, 73 insertions(+), 196 deletions(-) diff --git a/paddle/fluid/framework/rw_lock.h b/paddle/fluid/framework/rw_lock.h index da163835e86..a068d3543d9 100644 --- a/paddle/fluid/framework/rw_lock.h +++ b/paddle/fluid/framework/rw_lock.h @@ -56,76 +56,5 @@ struct RWLock { }; #endif -class RWLockGuard { - public: - enum Status { kUnLock, kWRLock, kRDLock }; - - RWLockGuard(RWLock* rw_lock, Status init_status) - : lock_(rw_lock), status_(Status::kUnLock) { - switch (init_status) { - case Status::kRDLock: { - RDLock(); - break; - } - case Status::kWRLock: { - WRLock(); - break; - } - case Status::kUnLock: { - break; - } - } - } - - void WRLock() { - switch (status_) { - case Status::kUnLock: { - lock_->WRLock(); - status_ = Status::kWRLock; - break; - } - case Status::kWRLock: { - break; - } - case Status::kRDLock: { - PADDLE_THROW( - "Please unlock read lock first before invoking write lock."); - break; - } - } - } - - void RDLock() { - switch (status_) { - case Status::kUnLock: { - lock_->RDLock(); - status_ = Status::kRDLock; - break; - } - case Status::kRDLock: { - break; - } - case Status::kWRLock: { - PADDLE_THROW( - "Please unlock write lock first before invoking read lock."); - break; - } - } - } - - void UnLock() { - if (status_ != Status::kUnLock) { - lock_->UNLock(); - status_ = Status::kUnLock; - } - } - - ~RWLockGuard() { UnLock(); } - - private: - RWLock* lock_; - Status status_; -}; - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 4a7a6bcf715..22cbf680c06 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -118,6 +118,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { output_channels / groups * output_height * output_width * output_depth; int group_offset_filter = filter->numel() / groups; // ------------------- cudnn conv workspace --------------------- + void* cudnn_workspace = nullptr; size_t workspace_size_in_bytes; // final workspace to allocate. size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; if (user_workspace_size > 0) { @@ -158,18 +159,20 @@ class CUDNNConvOpKernel : public framework::OpKernel { PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, "workspace_size to be allocated exceeds the limit"); + // Allocate on GPU memory + platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, - cudnn_filter_desc, filter_data + i * group_offset_filter, - cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, - &beta, cudnn_output_desc, output_data + i * group_offset_out)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_filter_desc, filter_data + i * group_offset_filter, + cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, + &beta, cudnn_output_desc, output_data + i * group_offset_out)); } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); } }; @@ -311,7 +314,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnn_filter_desc, filter_algo, &tmp_size)); workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); } - + // ------------------- cudnn conv workspace --------------------- + // Already on GPU + void* cudnn_workspace = nullptr; + platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; if (input_grad) { @@ -319,15 +326,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { // Because beta is zero, it is unnecessary to reset input_grad. for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, - filter_data + i * group_offset_filter, cudnn_output_grad_desc, - output_grad_data + i * group_offset_out, cudnn_conv_desc, - data_algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_input_desc, input_grad_data + i * group_offset_in)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, + filter_data + i * group_offset_filter, cudnn_output_grad_desc, + output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, + input_grad_data + i * group_offset_in)); } } // ------------------- cudnn conv backward filter --------------------- @@ -335,17 +339,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset filter_grad. for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_input_desc, - input_data + i * group_offset_in, cudnn_output_grad_desc, - output_grad_data + i * group_offset_out, cudnn_conv_desc, - filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_filter_desc, filter_grad_data + i * group_offset_filter)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_output_grad_desc, output_grad_data + i * group_offset_out, + cudnn_conv_desc, filter_algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_filter_desc, + filter_grad_data + i * group_offset_filter)); } } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); } }; diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc index 73831611d01..82fff68e755 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc @@ -76,6 +76,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { conv_desc.descriptor(paddings, strides, dilations); // ------------------- cudnn conv workspace --------------------- + void* cudnn_workspace = nullptr; size_t workspace_size_in_bytes; // final workspace to allocate. size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes; if (user_workspace_size > 0) { @@ -99,21 +100,25 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, cudnn_output_desc, algo, &workspace_size_in_bytes)); + // Allocate on GPU memory + platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // ------------------- cudnn conv transpose forward --------------------- int input_offset = input->numel() / input->dims()[0] / groups; int output_offset = output->numel() / output->dims()[0] / groups; int filter_offset = filter->numel() / groups; T alpha = 1.0f, beta = 0.0f; for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, - cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, - algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_output_desc, output_data + output_offset * g)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, + cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, + algo, cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_output_desc, output_data + output_offset * g)); } + + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); } }; @@ -201,6 +206,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { std::max(workspace_size_in_bytes, bwd_filter_ws_size); } + // ------------------- cudnn conv workspace --------------------- + // Already on GPU + void* cudnn_workspace = nullptr; + platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- // FIXME(typhoonzero): template type T may not be the same as cudnn call. int input_offset = input->numel() / input->dims()[0] / groups; @@ -212,15 +222,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_output_desc, - output_grad_data + output_grad_offset * g, cudnn_filter_desc, - filter_data + filter_offset * g, cudnn_conv_desc, data_algo, - cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, - input_grad_data + input_offset * g)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_output_desc, + output_grad_data + output_grad_offset * g, cudnn_filter_desc, + filter_data + filter_offset * g, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, + input_grad_data + input_offset * g)); } } @@ -230,17 +237,17 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { // Because beta is zero, it is unnecessary to reset filter_grad. // Gradient with respect to the filter for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_output_desc, - output_grad_data + output_grad_offset * g, cudnn_input_desc, - input_data + input_offset * g, cudnn_conv_desc, filter_algo, - cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_filter_desc, filter_grad_data + filter_offset * g)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_output_desc, + output_grad_data + output_grad_offset * g, cudnn_input_desc, + input_data + input_offset * g, cudnn_conv_desc, filter_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc, + filter_grad_data + filter_offset * g)); } } + + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); } }; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 3ec20ad7e5b..2cc26da013f 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -16,9 +16,6 @@ limitations under the License. */ #include #include "paddle/fluid/memory/memory.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/framework/rw_lock.h" -#endif namespace paddle { namespace platform { @@ -145,59 +142,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { mutable unsigned int* semaphore_; }; -class CudnnHolder { - public: - CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) - : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) { - PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); - PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_)); - } - - cudnnHandle_t cudnn_handle() const { return cudnn_handle_; } - - void RunFunc(const std::function& cudnn_func, - size_t required_workspace_len) { - std::lock_guard lock(mtx_); - if (required_workspace_len > workspace_len_) { - ReallocateWorkspace(required_workspace_len); - } - cudnn_func(workspace_); - } - - ~CudnnHolder() { - PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); - if (workspace_ != nullptr) { - paddle::memory::Free(place_, workspace_); - } - } - - private: - void ReallocateWorkspace(size_t required_workspace_len) { - if (required_workspace_len <= workspace_len_) { - return; - } - void* new_workspace = paddle::memory::Alloc(place_, required_workspace_len); - if (workspace_ != nullptr) { - // Maybe someone is using the current workspace - PADDLE_ENFORCE(cudaStreamSynchronize(*stream_)); - paddle::memory::Free(place_, workspace_); - } - workspace_ = new_workspace; - workspace_len_ = required_workspace_len; - } - - cudnnHandle_t cudnn_handle_; - void* workspace_; - size_t workspace_len_; - - const cudaStream_t* stream_; // not owned; - const CUDAPlace place_; - - std::mutex mtx_; -}; - -CUDADeviceContext::CUDADeviceContext(CUDAPlace place) - : place_(place), cudnn_holder_(nullptr) { +CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { SetDeviceId(place_.device); compute_capability = GetCUDAComputeCapability(place_.device); multi_process = GetCUDAMultiProcessors(place_.device); @@ -209,7 +154,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); if (dynload::HasCUDNN()) { - cudnn_holder_.reset(new CudnnHolder(&stream_, place)); + PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); + } else { + cudnn_handle_ = nullptr; } } @@ -217,6 +165,9 @@ CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); + if (cudnn_handle_ != nullptr) { + PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); + } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -245,14 +196,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const { return cublas_handle_; } -cudnnHandle_t CUDADeviceContext::cudnn_handle() const { - return cudnn_holder_->cudnn_handle(); -} - -void CUDADeviceContext::RunCudnnFuncWithWorkspace( - const std::function& cudnn_func, size_t workspace_len) const { - cudnn_holder_->RunFunc(cudnn_func, workspace_len); -} +cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 3ed49fc4233..b97dad20db0 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -69,7 +69,6 @@ struct DefaultDeviceContextType { #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; -class CudnnHolder; class CUDADeviceContext : public DeviceContext { public: @@ -97,11 +96,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; - /*! \brief Run a cudnn function with the workspace provided by - * CUDADeviceContext */ - void RunCudnnFuncWithWorkspace(const std::function& cudnn_func, - size_t workspace_len) const; - /*! \brief Return cuda stream in the device context. */ cudaStream_t stream() const; @@ -117,8 +111,8 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - std::unique_ptr cudnn_holder_; cudaStream_t stream_; + cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; int compute_capability; -- GitLab