Unverified commit 0a96ec69, authored by Zeng Jinle, committed by GitHub

fix conv v7 workspace size limit error, test=develop (#17902)

Parent 4d5f6937
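Background for the change: `cudnnGetConvolutionForwardAlgorithm_v7` (and the backward `_v7` variants) return the heuristically fastest algorithm without honoring any workspace limit, so the algorithm they pick may need more scratch memory than `workspace_size_limit` allows. The fix queries the workspace required by the chosen algorithm and, if it exceeds the limit, falls back to the pre-v7 query that accepts an explicit limit. Below is a minimal sketch of that selection logic for the forward pass, assuming cuDNN 7.x; the helper name `ChooseFwdAlgoWithinLimit` and the `CUDNN_CALL` macro are illustrative and not Paddle APIs.

```cpp
// Hedged sketch of the v7-with-fallback algorithm selection (cuDNN 7.x).
// ChooseFwdAlgoWithinLimit and CUDNN_CALL are illustrative names, not Paddle's.
#include <cudnn.h>
#include <cstdio>
#include <cstdlib>

#define CUDNN_CALL(expr)                                                   \
  do {                                                                     \
    cudnnStatus_t s = (expr);                                              \
    if (s != CUDNN_STATUS_SUCCESS) {                                       \
      fprintf(stderr, "cuDNN error %s at %s:%d\n", cudnnGetErrorString(s), \
              __FILE__, __LINE__);                                         \
      exit(1);                                                             \
    }                                                                      \
  } while (0)

cudnnConvolutionFwdAlgo_t ChooseFwdAlgoWithinLimit(
    cudnnHandle_t handle, cudnnTensorDescriptor_t x_desc,
    cudnnFilterDescriptor_t w_desc, cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t y_desc, size_t workspace_limit,
    size_t* workspace_bytes) {
  // 1. Ask the v7 heuristic for its best algorithm. It ignores any
  //    workspace limit, so its pick may need too much scratch memory.
  //    (A real implementation would also check `returned` and perf.status.)
  cudnnConvolutionFwdAlgoPerf_t perf;
  int returned = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm_v7(
      handle, x_desc, w_desc, conv_desc, y_desc, /*requestedAlgoCount=*/1,
      &returned, &perf));
  cudnnConvolutionFwdAlgo_t algo = perf.algo;

  // 2. Query how much workspace that algorithm actually needs.
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
      handle, x_desc, w_desc, conv_desc, y_desc, algo, workspace_bytes));

  // 3. If it exceeds the limit, fall back to the pre-v7 query, which does
  //    honor an explicit workspace limit, then re-measure the workspace.
  if (*workspace_bytes > workspace_limit) {
    CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(
        handle, x_desc, w_desc, conv_desc, y_desc,
        CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit, &algo));
    CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
        handle, x_desc, w_desc, conv_desc, y_desc, algo, workspace_bytes));
  }
  return algo;
}
```

The kernel below applies the same idea inside `#if CUDNN_VERSION >= 7001` guards, so pre-7.1 builds keep calling the old limited query directly.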
@@ -165,6 +165,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     // TODO(dangqingqing) simplify the following code by SearchAlgorithm in
     // conv_cudnn_helper.h
+    bool has_got_workspace_size = false;
     if ((!exhaustive_search) && (!half_float)) {
 #if CUDNN_VERSION >= 7001
       using perf_t = cudnnConvolutionFwdAlgoPerf_t;
@@ -176,11 +177,29 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
           cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
           perf_results.get()));
       algo = (perf_results.get())[best_algo_idx].algo;
-#else
-      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
-          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-          cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-          workspace_size_limit, &algo));
+
+      // get workspace size able to allocate
+      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
+          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+          cudnn_output_desc, algo, &workspace_size_in_bytes));
+
+      // NOTE(zjl): cudnnGetConvolutionForwardAlgorithm_v7 cannot limit the
+      // workspace size. If the workspace size found by v7 exceeds the limit,
+      // we should fall back to the non-v7 method to find another algorithm.
+      if (workspace_size_in_bytes > workspace_size_limit) {
+        VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                   "the workspace size request("
+                << workspace_size_in_bytes << ") exceeds the limit("
+                << workspace_size_limit << ")";
+#endif
+        CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
+            handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+            cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+            workspace_size_limit, &algo));
+#if CUDNN_VERSION >= 7001
+      } else {
+        has_got_workspace_size = true;
+      }
 #endif
       VLOG(3) << "cuDNN forward algo " << algo;
@@ -219,10 +238,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
           "cuDNN exhaustive search doesn't support half float.");
     }
-    // get workspace size able to allocate
-    CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
-        handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-        cudnn_output_desc, algo, &workspace_size_in_bytes));
+
+    if (!has_got_workspace_size) {
+      // get workspace size able to allocate
+      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
+          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+          cudnn_output_desc, algo, &workspace_size_in_bytes));
+    }
+
     // It is possible for float16 on Volta GPU to allocate more memory than
     // the limit because the algo is overridden to use tensor core.
     PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
@@ -366,6 +388,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     auto x_dims = framework::vectorize(input->dims());
     auto f_dims = framework::vectorize(filter->dims());
     auto handle = dev_ctx.cudnn_handle();
+
+    bool has_got_bwd_data_ws_size = false;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       if (exhaustive_search) {
@@ -431,28 +455,49 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
                                CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
           data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
         }
-#else
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
-                handle, cudnn_filter_desc,
-                // dyDesc: Handle to the previously initialized input
-                // differential
-                // tensor descriptor.
-                cudnn_output_grad_desc, cudnn_conv_desc,
-                // dxDesc: Handle to the previously initialized output tensor
-                // descriptor.
-                cudnn_input_desc,
-                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-                workspace_size_limit, &data_algo));
+
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
+                handle, cudnn_filter_desc, cudnn_output_grad_desc,
+                cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
+        auto new_workspace_size = std::max(workspace_size_in_bytes, tmp_size);
+
+        if (new_workspace_size > workspace_size_limit) {
+          VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                     "the workspace size request("
+                  << new_workspace_size << ") exceeds the limit("
+                  << workspace_size_limit << ")";
+#endif
+          CUDNN_ENFORCE(
+              platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
+                  handle, cudnn_filter_desc,
+                  // dyDesc: Handle to the previously initialized input
+                  // differential
+                  // tensor descriptor.
+                  cudnn_output_grad_desc, cudnn_conv_desc,
+                  // dxDesc: Handle to the previously initialized output tensor
+                  // descriptor.
+                  cudnn_input_desc,
+                  CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                  workspace_size_limit, &data_algo));
+#if CUDNN_VERSION >= 7001
+        } else {
+          workspace_size_in_bytes = new_workspace_size;
+          has_got_bwd_data_ws_size = true;
+        }
 #endif
       }
-      CUDNN_ENFORCE(
-          platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
-              handle, cudnn_filter_desc, cudnn_output_grad_desc,
-              cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
-      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+
+      if (!has_got_bwd_data_ws_size) {
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
+                handle, cudnn_filter_desc, cudnn_output_grad_desc,
+                cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
+        workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+      }
     }
 
+    bool has_got_bwd_filter_ws_size = false;
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
       if (exhaustive_search) {
@@ -495,22 +540,45 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
                 cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS,
                 &perf_count, perf_results.get()));
         filter_algo = (perf_results.get())[best_algo_idx].algo;
-#else
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
-                handle, cudnn_input_desc, cudnn_output_grad_desc,
-                cudnn_conv_desc, cudnn_filter_desc,
-                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-                workspace_size_limit, &filter_algo));
+
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
+                handle, cudnn_input_desc, cudnn_output_grad_desc,
+                cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
+        auto new_workspace_size = std::max(workspace_size_in_bytes, tmp_size);
+
+        if (new_workspace_size > workspace_size_limit) {
+          VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                     "the workspace size request("
+                  << new_workspace_size << ") exceeds the limit("
+                  << workspace_size_limit << ")";
 #endif
+          CUDNN_ENFORCE(
+              platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+                  handle, cudnn_input_desc, cudnn_output_grad_desc,
+                  cudnn_conv_desc, cudnn_filter_desc,
+                  CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                  workspace_size_limit, &filter_algo));
+#if CUDNN_VERSION >= 7001
+        } else {
+          workspace_size_in_bytes = new_workspace_size;
+          has_got_bwd_filter_ws_size = true;
+        }
+#endif
       }
-      CUDNN_ENFORCE(
-          platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
-              handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
-              cudnn_filter_desc, filter_algo, &tmp_size));
-      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+
+      if (!has_got_bwd_filter_ws_size) {
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
+                handle, cudnn_input_desc, cudnn_output_grad_desc,
+                cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
+        workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+      }
     }
 
+    PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
+                      "workspace_size to be allocated exceeds the limit");
+
     // ------------------- cudnn conv workspace ---------------------
     if (!cudnn_workspace_ptr) {
       cudnn_workspace =
......
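The backward-data and backward-filter hunks above follow the same pattern with one extra wrinkle: both gradients share a single workspace, so the candidate size is first combined with the running requirement via `std::max` and only committed (together with the `has_got_bwd_*_ws_size` flag) when it stays under the limit. Below is a hedged sketch of that bookkeeping for the backward-data case, again assuming cuDNN 7.x; the function and variable names are illustrative, and `CUDNN_CALL` is the error-checking macro from the sketch above, not a Paddle API.

```cpp
// Sketch (not Paddle code): pick a backward-data algorithm whose workspace,
// combined with an already-reserved shared workspace, stays under the limit.
#include <algorithm>
#include <cudnn.h>
// CUDNN_CALL: same illustrative error-checking macro as in the sketch above.

cudnnConvolutionBwdDataAlgo_t ChooseBwdDataAlgoWithinLimit(
    cudnnHandle_t handle, cudnnFilterDescriptor_t w_desc,
    cudnnTensorDescriptor_t dy_desc, cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t dx_desc, size_t workspace_limit,
    size_t* shared_workspace_bytes) {
  // Let the v7 heuristic propose an algorithm (it ignores workspace limits).
  cudnnConvolutionBwdDataAlgoPerf_t perf;
  int returned = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm_v7(
      handle, w_desc, dy_desc, conv_desc, dx_desc, /*requestedAlgoCount=*/1,
      &returned, &perf));
  cudnnConvolutionBwdDataAlgo_t algo = perf.algo;

  // Measure its workspace and fold it into the shared requirement.
  size_t tmp_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
      handle, w_desc, dy_desc, conv_desc, dx_desc, algo, &tmp_size));
  size_t candidate = std::max(*shared_workspace_bytes, tmp_size);

  if (candidate > workspace_limit) {
    // The v7 pick would blow the shared budget: fall back to the pre-v7
    // query that honors an explicit limit, then re-measure.
    CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(
        handle, w_desc, dy_desc, conv_desc, dx_desc,
        CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, workspace_limit,
        &algo));
    CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
        handle, w_desc, dy_desc, conv_desc, dx_desc, algo, &tmp_size));
    *shared_workspace_bytes = std::max(*shared_workspace_bytes, tmp_size);
  } else {
    *shared_workspace_bytes = candidate;  // fits: commit the combined size
  }
  return algo;
}
```

The backward-filter path is analogous, using the `BackwardFilter` queries, descriptors, and `CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT`.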