未验证 提交 76e92274 编写于 作者: G guochaorong 提交者: GitHub

Merge pull request #13199 from JiayiFeng/fix_CudnnHolder_bug

Fix cudnn holder bug
......@@ -56,5 +56,76 @@ struct RWLock {
};
#endif
class RWLockGuard {
public:
enum Status { kUnLock, kWRLock, kRDLock };
RWLockGuard(RWLock* rw_lock, Status init_status)
: lock_(rw_lock), status_(Status::kUnLock) {
switch (init_status) {
case Status::kRDLock: {
RDLock();
break;
}
case Status::kWRLock: {
WRLock();
break;
}
case Status::kUnLock: {
break;
}
}
}
void WRLock() {
switch (status_) {
case Status::kUnLock: {
lock_->WRLock();
status_ = Status::kWRLock;
break;
}
case Status::kWRLock: {
break;
}
case Status::kRDLock: {
PADDLE_THROW(
"Please unlock read lock first before invoking write lock.");
break;
}
}
}
void RDLock() {
switch (status_) {
case Status::kUnLock: {
lock_->RDLock();
status_ = Status::kRDLock;
break;
}
case Status::kRDLock: {
break;
}
case Status::kWRLock: {
PADDLE_THROW(
"Please unlock write lock first before invoking read lock.");
break;
}
}
}
void UnLock() {
if (status_ != Status::kUnLock) {
lock_->UNLock();
status_ = Status::kUnLock;
}
}
~RWLockGuard() { UnLock(); }
private:
RWLock* lock_;
Status status_;
};
} // namespace framework
} // namespace paddle
......@@ -118,7 +118,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
output_channels / groups * output_height * output_width * output_depth;
int group_offset_filter = filter->numel() / groups;
// ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
if (user_workspace_size > 0) {
......@@ -159,20 +158,18 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit");
// Allocate on GPU memory
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_filter_desc, filter_data + i * group_offset_filter,
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_filter_desc, filter_data + i * group_offset_filter,
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
......@@ -314,11 +311,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_filter_desc, filter_algo, &tmp_size));
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
}
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
......@@ -326,12 +319,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset input_grad.
for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc,
filter_data + i * group_offset_filter, cudnn_output_grad_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + i * group_offset_in));
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc,
filter_data + i * group_offset_filter, cudnn_output_grad_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc,
data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_input_desc, input_grad_data + i * group_offset_in));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
}
}
// ------------------- cudnn conv backward filter ---------------------
......@@ -339,16 +335,17 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad.
for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
cudnn_conv_desc, filter_algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_filter_desc,
filter_grad_data + i * group_offset_filter));
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc,
input_data + i * group_offset_in, cudnn_output_grad_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc,
filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
}
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
......
......@@ -76,7 +76,6 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
conv_desc.descriptor<T>(paddings, strides, dilations);
// ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes;
if (user_workspace_size > 0) {
......@@ -100,25 +99,21 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
// Allocate on GPU memory
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv transpose forward ---------------------
int input_offset = input->numel() / input->dims()[0] / groups;
int output_offset = output->numel() / output->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
for (int g = 0; g < groups; g++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
......@@ -206,11 +201,6 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
std::max(workspace_size_in_bytes, bwd_filter_ws_size);
}
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
int input_offset = input->numel() / input->dims()[0] / groups;
......@@ -222,12 +212,15 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
for (int g = 0; g < groups; g++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_filter_desc,
filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g));
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_filter_desc,
filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
}
}
......@@ -237,17 +230,17 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
for (int g = 0; g < groups; g++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc,
input_data + input_offset * g, cudnn_conv_desc, filter_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc,
filter_grad_data + filter_offset * g));
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc,
input_data + input_offset * g, cudnn_conv_desc, filter_algo,
cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + filter_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
}
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
......
......@@ -16,6 +16,9 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#endif
namespace paddle {
namespace platform {
......@@ -142,7 +145,58 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_;
};
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) {
std::lock_guard<std::mutex> lock(mtx_);
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
}
~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (workspace_ != nullptr) {
paddle::memory::Free(place_, workspace_);
}
}
private:
void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) {
return;
}
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
paddle::memory::Free(place_, workspace_);
}
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
workspace_len_ = required_workspace_len;
}
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
SetDeviceId(place_.device);
compute_capability = GetCUDAComputeCapability(place_.device);
multi_process = GetCUDAMultiProcessors(place_.device);
......@@ -154,10 +208,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
if (dynload::HasCUDNN()) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
} else {
cudnn_handle_ = nullptr;
cudnn_holder_.reset(new CudnnHolder(&stream_, place));
}
}
......@@ -165,9 +216,6 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
if (cudnn_handle_ != nullptr) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
......@@ -196,7 +244,14 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->cudnn_handle();
}
void CUDADeviceContext::RunCudnnFuncWithWorkspace(
const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
cudnn_holder_->RunFunc(cudnn_func, workspace_len);
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
......@@ -69,6 +69,7 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice;
class CudnnHolder;
class CUDADeviceContext : public DeviceContext {
public:
......@@ -96,6 +97,11 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
/*! \brief Run a cudnn function with the workspace provided by
* CUDADeviceContext */
void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func,
size_t workspace_len) const;
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
......@@ -111,8 +117,8 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_;
int compute_capability;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册