diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h
index 55502eaf4e54957edb1f7c3cfa9616be3f99cf6a..2ba58a6dae5b355a80caca9f976919d1f56af256 100644
--- a/paddle/fluid/operators/conv_cudnn_helper.h
+++ b/paddle/fluid/operators/conv_cudnn_helper.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include
 #include
 #include
+
 #include "paddle/fluid/framework/conv_search_cache.h"
 #include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
@@ -101,6 +102,30 @@ inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) {
   return max_algos;
 }
 
+// Scans perf_results[0, perf_num) in index order and stores in *algo the algo
+// of the first entry whose status is CUDNN_STATUS_SUCCESS and whose memory is
+// strictly less than workspace_byte. Every entry up to the match is read, so
+// all perf_num entries must be initialized by the caller. If no entry
+// qualifies, *algo is left untouched and only a VLOG(3) message is emitted.
+template <typename PerfType, typename AlgoType>
+void ChooseAlgoByWorkspace(PerfType* perf_results, size_t perf_num,
+                           size_t workspace_byte, AlgoType* algo) {
+  for (size_t i = 0; i < perf_num; ++i) {
+    // Bind by reference: the entry is only read, so no copy is needed.
+    const auto& result = perf_results[i];
+    if (result.status == CUDNN_STATUS_SUCCESS &&
+        result.memory < workspace_byte) {
+      *algo = result.algo;
+      VLOG(3) << " algo: " << result.algo << ", time: " << result.time
+              << " ms, wksp = " << result.memory
+              << ", status = " << result.status;
+      return;
+    }
+  }
+  VLOG(3) << "Can not find algo that requires memory < "
+          << static_cast<double>(workspace_byte) / (1 << 20) << " MB";
+}
+
 template
 void ChooseAlgo(const std::vector& perf_results,
                 size_t workspace_byte, AlgoType* algo) {
@@ -219,7 +244,10 @@ struct SearchAlgorithm {
 
       if (workspace_size > workspace_size_limit) {
 #if CUDNN_VERSION >= 8000
-        workspace_size_limit = workspace_size;
+        // cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8
+        ChooseAlgoByWorkspace(perf_results.get(),
+                              kNUM_CUDNN_FWD_ALGS,
+                              workspace_size_limit, &algo);
 #else
         VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
                    "the workspace size request("
@@ -316,7 +344,6 @@ struct SearchAlgorithm {
     size_t workspace_size = 0;
     bool has_got_workspace_size = true;
     algo_t algo;
-
 #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
     auto& dev_ctx = ctx.template device_context();
     if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
@@ -362,9 +389,10 @@ struct SearchAlgorithm {
       if (workspace_size > workspace_size_limit) {
         has_got_workspace_size = false;
 #if CUDNN_VERSION >= 8000
-        // There is no cudnnGetConvolutionBackwardDataAlgorithm in CUDNN 8
-        // version.
-        workspace_size_limit = workspace_size;
+        // cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8
+        ChooseAlgoByWorkspace(perf_results.get(),
+                              kNUM_CUDNN_BWD_DATA_ALGS,
+                              workspace_size_limit, &algo);
 #else
         VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
                    "the workspace size request("
@@ -493,6 +521,23 @@ struct SearchAlgorithm {
       workspace_size = GetWorkspaceSize(args, algo);
       if (workspace_size > workspace_size_limit) {
         workspace_size = workspace_size_limit;
+#if CUDNN_VERSION >= 8000
+        // cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8
+        ChooseAlgoByWorkspace(perf_results.get(),
+                              kNUM_CUDNN_BWD_FILTER_ALGS,
+                              workspace_size_limit, &algo);
+#else
+        VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
+                   "the workspace size request("
+                << workspace_size << ") exceeds the limit("
+                << workspace_size_limit << ")";
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+                args.handle, args.idesc.desc(), args.odesc.desc(),
+                args.cdesc.desc(), args.wdesc.desc(),
+                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                workspace_size_limit, &algo));
+#endif
       }
 #else
       PADDLE_ENFORCE_CUDA_SUCCESS(