未验证 提交 f962bd34 编写于 作者: L Leo Chen 提交者: GitHub

Fix cudnn workspace limit in cudnn-8 (#28611)

上级 90805e2d
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
......@@ -101,6 +102,24 @@ inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) {
return max_algos;
}
template <typename PerfType, typename AlgoType>
void ChooseAlgoByWorkspace(PerfType* perf_results, size_t perf_num,
size_t workspace_byte, AlgoType* algo) {
for (size_t i = 0; i < perf_num; ++i) {
auto result = perf_results[i];
if (result.status == CUDNN_STATUS_SUCCESS &&
result.memory < workspace_byte) {
*algo = result.algo;
VLOG(3) << " algo: " << result.algo << ", time: " << result.time
<< " ms, wksp = " << result.memory
<< ", status = " << result.status;
return;
}
}
VLOG(3) << "Can not find alog that requires memory < "
<< static_cast<double>(workspace_byte) / (1 << 20) << " MB";
}
template <typename PerfType, typename AlgoType>
void ChooseAlgo(const std::vector<PerfType>& perf_results,
size_t workspace_byte, AlgoType* algo) {
......@@ -219,7 +238,10 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
if (workspace_size > workspace_size_limit) {
#if CUDNN_VERSION >= 8000
workspace_size_limit = workspace_size;
// cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
kNUM_CUDNN_FWD_ALGS,
workspace_size_limit, &algo);
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
......@@ -316,7 +338,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
size_t workspace_size = 0;
bool has_got_workspace_size = true;
algo_t algo;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
......@@ -362,9 +383,10 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
if (workspace_size > workspace_size_limit) {
has_got_workspace_size = false;
#if CUDNN_VERSION >= 8000
// There is no cudnnGetConvolutionBackwardDataAlgorithm in CUDNN 8
// version.
workspace_size_limit = workspace_size;
// cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
kNUM_CUDNN_BWD_DATA_ALGS,
workspace_size_limit, &algo);
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
......@@ -493,6 +515,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
workspace_size = GetWorkspaceSize(args, algo);
if (workspace_size > workspace_size_limit) {
workspace_size = workspace_size_limit;
#if CUDNN_VERSION >= 8000
// cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
kNUM_CUDNN_BWD_FILTER_ALGS,
workspace_size_limit, &algo);
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(),
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
}
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册