未验证 提交 c17f9cf2 编写于 作者: S Shang Zhizhou 提交者: GitHub

[bug fix]:Memory increases after adapting the cudnn version to cudnn8 (#27436)

* [bug fix]:Memory increases after adapting the cudnn version to 8

* [bug fix]cudnnGetConvolutionForwardAlgorithm not defined
上级 1e1ae5c5
......@@ -162,7 +162,20 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
workspace_size = GetWorkspaceSize(args, algo);
if (workspace_size > workspace_size_limit) {
#if CUDNN_VERSION >= 8000
workspace_size_limit = workspace_size;
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardAlgorithm(
args.handle, args.idesc.desc(), args.wdesc.desc(),
args.cdesc.desc(), args.odesc.desc(),
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
}
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
......@@ -291,8 +304,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
#endif
workspace_size = GetWorkspaceSize(args, algo);
if (workspace_size > workspace_size_limit) {
workspace_size_limit = workspace_size;
has_got_workspace_size = false;
#if CUDNN_VERSION >= 8000
// There is no cudnnGetConvolutionBackwardDataAlgorithm in CUDNN 8
// version.
workspace_size_limit = workspace_size;
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
args.handle, args.wdesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.idesc.desc(),
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
}
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
......
......@@ -204,6 +204,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims());
if (!exhaustive_search) {
#if CUDNN_VERSION >= 8000
int perf_count;
int best_algo_idx = 0;
size_t tmp_size = 0;
......@@ -215,13 +216,20 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
VLOG(3) << "cuDNN forward algo " << algo;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
if (workspace_size_in_bytes > workspace_size_limit)
workspace_size_limit = workspace_size_in_bytes;
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
VLOG(3) << "cuDNN forward algo " << algo;
#endif
} else {
std::function<cudnnConvolutionFwdAlgo_t()> search_func =
[&]() -> cudnnConvolutionFwdAlgo_t {
......
......@@ -30,6 +30,10 @@ CUDNN_DNN_ROUTINE_EACH_R2(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R3_LESS_R8
CUDNN_DNN_ROUTINE_EACH_AFTER_R3_LESS_R8(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DEFINE_WRAP);
#endif
......@@ -54,6 +58,10 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_R8
CUDNN_DNN_ROUTINE_EACH_R8(DEFINE_WRAP);
#endif
bool HasCUDNN() {
std::call_once(cudnn_dso_flag,
[]() { cudnn_dso_handle = GetCUDNNDsoHandle(); });
......
......@@ -134,6 +134,7 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3_LESS_R8(__macro) \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnGetConvolutionBackwardDataAlgorithm); \
__macro(cudnnSetRNNDescriptor);
CUDNN_DNN_ROUTINE_EACH_AFTER_R3_LESS_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册